From a6c206ea66146b374704a74321271156b8d04c04 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 30 Dec 2024 16:40:36 -0800 Subject: [PATCH 01/12] [bugfix] fix prompt_adapter interleaved_content_convert_to_raw (#696) # What does this PR do? - fix interleaved_content_convert_to_raw in prompt_adapter to correctly convert ImageContentItem to RawMediaItem with raw data bytes ## Test Plan ``` torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py ``` **Before** image **After** image ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- .../utils/inference/prompt_adapter.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index f7d2cd84e..ed0cabe1c 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -40,7 +40,6 @@ from llama_stack.apis.common.content_types import ( InterleavedContent, InterleavedContentItem, TextContentItem, - URL, ) from llama_stack.apis.inference import ( @@ -117,27 +116,31 @@ async def interleaved_content_convert_to_raw( elif isinstance(c, TextContentItem): return RawTextItem(text=c.text) elif isinstance(c, ImageContentItem): - # load image and return PIL version - img = c.data - if isinstance(img, URL): - if img.uri.startswith("data"): - match = re.match(r"data:image/(\w+);base64,(.+)", img.uri) + if c.url: + # Load image bytes from URL + if c.url.uri.startswith("data"): + match = re.match(r"data:image/(\w+);base64,(.+)", c.url.uri) if not match: - raise ValueError("Invalid data URL format") + raise ValueError( + f"Invalid data URL format, {c.url.uri[:40]}..." + ) _, image_data = match.groups() data = base64.b64decode(image_data) - elif img.uri.startswith("file://"): - path = img.uri[len("file://") :] + elif c.url.uri.startswith("file://"): + path = c.url.uri[len("file://") :] with open(path, "rb") as f: data = f.read() # type: ignore - elif img.uri.startswith("http"): + elif c.url.uri.startswith("http"): async with httpx.AsyncClient() as client: - response = await client.get(img.uri) + response = await client.get(c.url.uri) data = response.content else: raise ValueError("Unsupported URL type") - else: + elif c.data: data = c.data + else: + raise ValueError("No data or URL provided") + return RawMediaItem(data=data) else: raise ValueError(f"Unsupported content type: {type(c)}") From eee25db11ddc77af64a52adbd7de985cd20c01b7 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Thu, 2 Jan 2025 11:03:30 -0600 Subject: [PATCH 02/12] Add missing "inline::" prefix for providers in building_distro.md (#702) This fixes the following errors: ``` ValueError: Provider `meta-reference` is not available for API `agents` ValueError: Provider `meta-reference` is not available for API `telemetry` ``` --- docs/source/distributions/building_distro.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/distributions/building_distro.md b/docs/source/distributions/building_distro.md index 67d39159c..cc94fa9db 100644 --- a/docs/source/distributions/building_distro.md +++ b/docs/source/distributions/building_distro.md @@ -338,8 +338,8 @@ distribution_spec: inference: remote::ollama memory: inline::faiss safety: inline::llama-guard - agents: meta-reference - telemetry: meta-reference + agents: inline::meta-reference + telemetry: inline::meta-reference image_type: conda ``` From c1987d6143f22574ce83ee134ec282fcb9589715 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Thu, 2 Jan 2025 11:04:07 -0600 Subject: [PATCH 03/12] Fix failing flake8 E226 check (#701) This fixes the pre-commit check when running locally (not sure why this was not caught on CI check): ``` > pre-commit run --show-diff-on-failure --color=always --all-files trim trailing whitespace.................................................Passed check python ast.........................................................Passed check for merge conflicts................................................Passed check for added large files..............................................Passed fix end of files.........................................................Passed Insert license in comments...............................................Passed flake8...................................................................Failed - hook id: flake8 - exit code: 1 llama_stack/distribution/ui/page/evaluations/app_eval.py:132:65: E226 missing whitespace around arithmetic operator llama_stack/distribution/ui/page/evaluations/native_eval.py:235:61: E226 missing whitespace around arithmetic operator llama_stack/providers/utils/telemetry/trace_protocol.py:56:78: E226 missing whitespace around arithmetic operator ``` Signed-off-by: Yuan Tang --- llama_stack/distribution/ui/page/evaluations/app_eval.py | 2 +- llama_stack/distribution/ui/page/evaluations/native_eval.py | 2 +- llama_stack/providers/utils/telemetry/trace_protocol.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llama_stack/distribution/ui/page/evaluations/app_eval.py b/llama_stack/distribution/ui/page/evaluations/app_eval.py index 5ec47ed45..a9dd50a04 100644 --- a/llama_stack/distribution/ui/page/evaluations/app_eval.py +++ b/llama_stack/distribution/ui/page/evaluations/app_eval.py @@ -129,7 +129,7 @@ def application_evaluation_page(): # Display current row results using separate containers progress_text_container.write( - f"Expand to see current processed result ({i+1}/{len(rows)})" + f"Expand to see current processed result ({i + 1} / {len(rows)})" ) results_container.json( score_res.to_json(), diff --git a/llama_stack/distribution/ui/page/evaluations/native_eval.py b/llama_stack/distribution/ui/page/evaluations/native_eval.py index b8cc8bfa6..2cbc8d63e 100644 --- a/llama_stack/distribution/ui/page/evaluations/native_eval.py +++ b/llama_stack/distribution/ui/page/evaluations/native_eval.py @@ -232,7 +232,7 @@ def run_evaluation_3(): output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0]) progress_text_container.write( - f"Expand to see current processed result ({i+1}/{len(rows)})" + f"Expand to see current processed result ({i + 1} / {len(rows)})" ) results_container.json(eval_res, expanded=2) diff --git a/llama_stack/providers/utils/telemetry/trace_protocol.py b/llama_stack/providers/utils/telemetry/trace_protocol.py index 31897c0ae..38a56fdac 100644 --- a/llama_stack/providers/utils/telemetry/trace_protocol.py +++ b/llama_stack/providers/utils/telemetry/trace_protocol.py @@ -53,7 +53,7 @@ def trace_protocol(cls: Type[T]) -> Type[T]: combined_args = {} for i, arg in enumerate(args): param_name = ( - param_names[i] if i < len(param_names) else f"position_{i+1}" + param_names[i] if i < len(param_names) else f"position_{i + 1}" ) combined_args[param_name] = serialize_value(arg) for k, v in kwargs.items(): From 8146dce11e290fd0e9925f46df8766dfe218a421 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Thu, 2 Jan 2025 11:04:29 -0600 Subject: [PATCH 04/12] Add missing newlines before printing the Dockerfile content (#700) Before: ``` Dockerfile created successfully in /tmp/tmp.qyMdb0vI8X/DockerfileFROM python:3.10-slim WORKDIR /app RUN apt-get update && apt-get install -y iputils-ping net-tools iproute2 dnsutils telnet curl wget telnet procps psmisc lsof traceroute bubblewrap && rm -rf /var/lib/apt/lists/* ``` After: ``` Dockerfile created successfully in /tmp/tmp.qyMdb0vI8X/Dockerfile FROM python:3.10-slim WORKDIR /app RUN apt-get update && apt-get install -y iputils-ping net-tools iproute2 dnsutils telnet curl wget telnet procps psmisc lsof traceroute bubblewrap && rm -rf /var/lib/apt/lists/* ``` Signed-off-by: Yuan Tang --- llama_stack/distribution/build_container.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/distribution/build_container.sh b/llama_stack/distribution/build_container.sh index a9aee8f14..49e65b8cb 100755 --- a/llama_stack/distribution/build_container.sh +++ b/llama_stack/distribution/build_container.sh @@ -126,7 +126,7 @@ ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--templat EOF -printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile" +printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile\n\n" cat $TEMP_DIR/Dockerfile printf "\n" From 5d7b61133657a92e3584fbcefc744ddd333d743f Mon Sep 17 00:00:00 2001 From: Aidan Do Date: Fri, 3 Jan 2025 04:05:51 +1100 Subject: [PATCH 05/12] Add JSON structured outputs to Ollama Provider (#680) # What does this PR do? Addresses issue #679 - Adds support for the response_format field for chat completions and completions so users can get their outputs in JSON ## Test Plan
Integration tests `pytest llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output -k ollama -s -v` ```python llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_8b-ollama] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_3b-ollama] PASSED ================================== 2 passed, 18 deselected, 3 warnings in 41.41s ================================== ```
Manual Tests ``` export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct export OLLAMA_INFERENCE_MODEL=llama3.2:3b-instruct-fp16 export LLAMA_STACK_PORT=5000 ollama run $OLLAMA_INFERENCE_MODEL --keepalive 60m llama stack build --template ollama --image-type conda llama stack run ./run.yaml \ --port $LLAMA_STACK_PORT \ --env INFERENCE_MODEL=$INFERENCE_MODEL \ --env OLLAMA_URL=http://localhost:11434 ``` ```python client = LlamaStackClient(base_url=f"http://localhost:{os.environ['LLAMA_STACK_PORT']}") MODEL_ID=meta-llama/Llama-3.2-3B-Instruct prompt =f""" Create a step by step plan to complete the task of creating a codebase that is a web server that has an API endpoint that translates text from English to French. You have 3 different operations you can perform. You can create a file, update a file, or delete a file. Limit your step by step plan to only these operations per step. Don't create more than 10 steps. Please ensure there's a README.md file in the root of the codebase that describes the codebase and how to run it. Please ensure there's a requirements.txt file in the root of the codebase that describes the dependencies of the codebase. """ response = client.inference.chat_completion( model_id=MODEL_ID, messages=[ {"role": "user", "content": prompt}, ], sampling_params={ "max_tokens": 200000, }, response_format={ "type": "json_schema", "json_schema": { "$schema": "http://json-schema.org/draft-07/schema#", "title": "Plan", "description": f"A plan to complete the task of creating a codebase that is a web server that has an API endpoint that translates text from English to French.", "type": "object", "properties": { "steps": { "type": "array", "items": { "type": "string" } } }, "required": ["steps"], "additionalProperties": False, } }, stream=True, ) content = "" for chunk in response: if chunk.event.delta: print(chunk.event.delta, end="", flush=True) content += chunk.event.delta try: plan = json.loads(content) print(plan) except Exception as e: print(f"Error parsing plan into JSON: {e}") plan = {"steps": []} ``` Outputs: ```json { "steps": [ "Update the requirements.txt file to include the updated dependencies specified in the peer's feedback, including the Google Cloud Translation API key.", "Update the app.py file to address the code smells and incorporate the suggested improvements, such as handling errors and exceptions, initializing the Translator object correctly, adding input validation, using type hints and docstrings, and removing unnecessary logging statements.", "Create a README.md file that describes the codebase and how to run it.", "Ensure the README.md file is up-to-date and accurate.", "Update the requirements.txt file to reflect any additional dependencies specified by the peer's feedback.", "Add documentation for each function in the app.py file using docstrings.", "Implement logging statements throughout the app.py file to monitor application execution.", "Test the API endpoint to ensure it correctly translates text from English to French and handles errors properly.", "Refactor the code to follow PEP 8 style guidelines and ensure consistency in naming conventions, indentation, and spacing.", "Create a new folder for logs and add a logging configuration file (e.g., logconfig.json) that specifies the logging level and output destination.", "Deploy the web server on a production environment (e.g., AWS Elastic Beanstalk or Google Cloud Platform) to make it accessible to external users." ] } ```
## Sources - Ollama api docs: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion - Ollama structured output docs: https://github.com/ollama/ollama/blob/main/docs/api.md#request-structured-outputs ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [x] Wrote necessary unit or integration tests. --- llama_stack/providers/remote/inference/ollama/ollama.py | 9 +++++++++ .../providers/tests/inference/test_text_inference.py | 2 ++ 2 files changed, 11 insertions(+) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 88f985f3a..2de5a994e 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -236,6 +236,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): tool_prompt_format=tool_prompt_format, stream=stream, logprobs=logprobs, + response_format=response_format, ) if stream: return self._stream_chat_completion(request) @@ -279,6 +280,14 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): ) input_dict["raw"] = True + if fmt := request.response_format: + if fmt.type == "json_schema": + input_dict["format"] = fmt.json_schema + elif fmt.type == "grammar": + raise NotImplementedError("Grammar response format is not supported") + else: + raise ValueError(f"Unknown response format type: {fmt.type}") + return { "model": request.model, **input_dict, diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 2eeda0dbf..fd93857a3 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -210,6 +210,7 @@ class TestInference: provider = inference_impl.routing_table.get_provider_impl(inference_model) if provider.__provider_spec__.provider_type not in ( "inline::meta-reference", + "remote::ollama", "remote::tgi", "remote::together", "remote::fireworks", @@ -272,6 +273,7 @@ class TestInference: provider = inference_impl.routing_table.get_provider_impl(inference_model) if provider.__provider_spec__.provider_type not in ( "inline::meta-reference", + "remote::ollama", "remote::fireworks", "remote::tgi", "remote::together", From 49ad16833694b27d710fced59a2720c6a2a0b257 Mon Sep 17 00:00:00 2001 From: Aidan Do Date: Fri, 3 Jan 2025 04:21:35 +1100 Subject: [PATCH 06/12] [#407] Agents: Avoid calling tools that haven't been explicitly enabled (#637) # What does this PR do? Contributes to issue (#407) tl;dr - @subramen was getting a 500 error because llama-stack called code_interpreter when it never was defined as a tool. Prevents failures like: image ``` # Server side Traceback (most recent call last): File "/opt/conda/envs/llamastack-vllm-stack/lib/python3.10/site-packages/llama_stack/distribution/server/server.py", line 206, in sse_generator async for item in await event_gen: File "/opt/conda/envs/llamastack-vllm-stack/lib/python3.10/site-packages/llama_stack/providers/impls/meta_reference/agents/agents.py", line 138, in _create_agent_turn_streaming async for event in agent.create_and_execute_turn(request): File "/opt/conda/envs/llamastack-vllm-stack/lib/python3.10/site-packages/llama_stack/providers/impls/meta_reference/agents/agent_instance.py", line 179, in create_and_execute_turn async for chunk in self.run( File "/opt/conda/envs/llamastack-vllm-stack/lib/python3.10/site-packages/llama_stack/providers/impls/meta_reference/agents/agent_instance.py", line 252, in run async for res in self._run( File "/opt/conda/envs/llamastack-vllm-stack/lib/python3.10/site-packages/llama_stack/providers/impls/meta_reference/agents/agent_instance.py", line 560, in _run result_messages = await execute_tool_call_maybe( File "/opt/conda/envs/llamastack-vllm-stack/lib/python3.10/site-packages/llama_stack/providers/impls/meta_reference/agents/agent_instance.py", line 824, in execute_tool_call_maybe assert name in tools_dict, f"Tool {name} not found" AssertionError: Tool code_interpreter not found ``` Instead, if the model hallucinates, we just let it hallucinate and let the client know. image ## Test Plan
pytest llama_stack/providers/tests/agents/test_agents.py -k ollama ``` llama stack build --template ollama --image-type conda conda activate llamastack-ollama ``` ``` llama_stack/providers/tests/agents/test_agents.py ..Fss [100%] ======================================================================= FAILURES ======================================================================= _________________________________________ TestAgents.test_rag_agent_as_attachments[--ollama][ollama] __________________________________________ llama_stack/providers/tests/agents/test_agents.py:261: in test_rag_agent_as_attachments turn_response = [ llama_stack/providers/tests/agents/test_agents.py:261: in turn_response = [ llama_stack/providers/inline/agents/meta_reference/agents.py:153: in _create_agent_turn_streaming async for event in agent.create_and_execute_turn(request): llama_stack/providers/inline/agents/meta_reference/agent_instance.py:179: in create_and_execute_turn async for chunk in self.run( llama_stack/providers/inline/agents/meta_reference/agent_instance.py:250: in run async for res in self._run( llama_stack/providers/inline/agents/meta_reference/agent_instance.py:363: in _run rag_context, bank_ids = await self._retrieve_context( llama_stack/providers/inline/agents/meta_reference/agent_instance.py:698: in _retrieve_context bank_id = await self._ensure_memory_bank(session_id) llama_stack/providers/inline/agents/meta_reference/agent_instance.py:653: in _ensure_memory_bank await self.memory_banks_api.register_memory_bank( llama_stack/providers/utils/telemetry/trace_protocol.py:101: in async_wrapper result = await method(self, *args, **kwargs) llama_stack/distribution/routers/routing_tables.py:312: in register_memory_bank raise ValueError( E ValueError: Embeddings are now served via Inference providers. Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example. =============================================================== short test summary info ================================================================ FAILED llama_stack/providers/tests/agents/test_agents.py::TestAgents::test_rag_agent_as_attachments[--ollama] - ValueError: Embeddings are now served via Inference providers. Please upgrade your run.yaml to include inline::sentence-transformer as an additiona... ========================================== 1 failed, 2 passed, 2 skipped, 20 deselected, 5 warnings in 14.24s ========================================== ``` Unrelated test is failing (also failing on main)
Manual Using this client code: https://github.com/aidando73/llama-stack-apps/blob/7ebc257b27bb120fe13e11d9d668a467a33e137d/client.py Screenshot 2024-12-16 at 17 41 31
## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- .../providers/inline/agents/meta_reference/agent_instance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index f225f5393..09738d7b7 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -584,7 +584,7 @@ class ChatAgent(ShieldRunnerMixin): tool_call = message.tool_calls[0] name = tool_call.tool_name - if not isinstance(name, BuiltinTool): + if not isinstance(name, BuiltinTool) or name not in enabled_tools: yield message return From 8e5b33679224a4d747cc01989a9b9c0cee5d2465 Mon Sep 17 00:00:00 2001 From: Justin Lee Date: Fri, 3 Jan 2025 03:18:07 +0800 Subject: [PATCH 07/12] Made changes to readme and pinning to llamastack v0.0.61 (#624) # What does this PR do? Pinning zero2hero to 0.0.61 and updated readme ## Test Plan Please describe: - Did a end to end test on the server and inference for 0.0.61 Server output: image ## Before submitting - [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [x] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- docs/zero_to_hero_guide/00_Inference101.ipynb | 12 +--- docs/zero_to_hero_guide/README.md | 68 ++++++++++--------- 2 files changed, 36 insertions(+), 44 deletions(-) diff --git a/docs/zero_to_hero_guide/00_Inference101.ipynb b/docs/zero_to_hero_guide/00_Inference101.ipynb index 2aced6ef9..687f5606b 100644 --- a/docs/zero_to_hero_guide/00_Inference101.ipynb +++ b/docs/zero_to_hero_guide/00_Inference101.ipynb @@ -358,7 +358,7 @@ " if not stream:\n", " cprint(f'> Response: {response.completion_message.content}', 'cyan')\n", " else:\n", - " async for log in EventLogger().log(response):\n", + " for log in EventLogger().log(response):\n", " log.print()\n", "\n", "# In a Jupyter Notebook cell, use `await` to call the function\n", @@ -366,16 +366,6 @@ "# To run it in a python file, use this line instead\n", "# asyncio.run(run_main())\n" ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "9399aecc", - "metadata": {}, - "outputs": [], - "source": [ - "#fin" - ] } ], "metadata": { diff --git a/docs/zero_to_hero_guide/README.md b/docs/zero_to_hero_guide/README.md index 68c012164..b451e0af7 100644 --- a/docs/zero_to_hero_guide/README.md +++ b/docs/zero_to_hero_guide/README.md @@ -45,7 +45,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next --- -## Install Dependencies and Set Up Environment +## Install Dependencies and Set Up Environmen 1. **Create a Conda Environment**: Create a new Conda environment with Python 3.10: @@ -73,7 +73,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next Open a new terminal and install `llama-stack`: ```bash conda activate ollama - pip install llama-stack==0.0.55 + pip install llama-stack==0.0.61 ``` --- @@ -96,7 +96,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next 3. **Set the ENV variables by exporting them to the terminal**: ```bash export OLLAMA_URL="http://localhost:11434" - export LLAMA_STACK_PORT=5051 + export LLAMA_STACK_PORT=5001 export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B" ``` @@ -104,34 +104,29 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next 3. **Run the Llama Stack**: Run the stack with command shared by the API from earlier: ```bash - llama stack run ollama \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=$INFERENCE_MODEL \ - --env SAFETY_MODEL=$SAFETY_MODEL \ + llama stack run ollama + --port $LLAMA_STACK_PORT + --env INFERENCE_MODEL=$INFERENCE_MODEL + --env SAFETY_MODEL=$SAFETY_MODEL --env OLLAMA_URL=$OLLAMA_URL ``` Note: Everytime you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model. -The server will start and listen on `http://localhost:5051`. +The server will start and listen on `http://localhost:5001`. --- ## Test with `llama-stack-client` CLI -After setting up the server, open a new terminal window and install the llama-stack-client package. +After setting up the server, open a new terminal window and configure the llama-stack-client. -1. Install the llama-stack-client package +1. Configure the CLI to point to the llama-stack server. ```bash - conda activate ollama - pip install llama-stack-client - ``` -2. Configure the CLI to point to the llama-stack server. - ```bash - llama-stack-client configure --endpoint http://localhost:5051 + llama-stack-client configure --endpoint http://localhost:5001 ``` **Expected Output:** ```bash - Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:5051 + Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:5001 ``` -3. Test the CLI by running inference: +2. Test the CLI by running inference: ```bash llama-stack-client inference chat-completion --message "Write me a 2-sentence poem about the moon" ``` @@ -153,16 +148,18 @@ After setting up the server, open a new terminal window and install the llama-st After setting up the server, open a new terminal window and verify it's working by sending a `POST` request using `curl`: ```bash -curl http://localhost:$LLAMA_STACK_PORT/inference/chat_completion \ --H "Content-Type: application/json" \ --d '{ - "model": "Llama3.2-3B-Instruct", +curl http://localhost:$LLAMA_STACK_PORT/alpha/inference/chat-completion +-H "Content-Type: application/json" +-d @- < Date: Thu, 2 Jan 2025 11:21:33 -0800 Subject: [PATCH 08/12] [rag evals] refactor & add ability to eval retrieval + generation in agentic eval pipeline (#664) # What does this PR do? - See https://github.com/meta-llama/llama-stack/pull/666 & https://github.com/meta-llama/llama-stack/pull/668 - Refactor BaseScoringFn to be just a minimal interface, add new RegistrableBaseScoring - Refactor data schema check - To separately evaluate retrieval component in RAG, we will have scoring functions needing "context" column additionally. - Refactor braintrust eval (more scoring fn added & tested in following PR) ## Test Plan ``` pytest -v -s -m llm_as_judge_scoring_together_inference scoring/test_scoring.py --judge-model meta-llama/Llama-3.2-3B-Instruct pytest -v -s -m basic_scoring_together_inference scoring/test_scoring.py pytest -v -s -m braintrust_scoring_together_inference scoring/test_scoring.py ``` image ``` pytest -v -s -m meta_reference_eval_together_inference eval/test_eval.py pytest -v -s -m meta_reference_eval_together_inference_huggingface_datasetio eval/test_eval.py ``` image ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- llama_stack/apis/scoring/scoring.py | 4 +- .../inline/eval/meta_reference/eval.py | 72 ++++----- .../providers/inline/scoring/basic/scoring.py | 34 ++-- .../basic/scoring_fn/equality_scoring_fn.py | 4 +- .../scoring_fn/regex_parser_scoring_fn.py | 4 +- .../basic/scoring_fn/subset_of_scoring_fn.py | 4 +- .../inline/scoring/braintrust/braintrust.py | 149 ++++++++++++++---- .../scoring_fn/fn_defs/answer_correctness.py | 15 +- .../scoring_fn/fn_defs/answer_relevancy.py | 26 +++ .../scoring_fn/fn_defs/answer_similarity.py | 26 +++ .../fn_defs/context_entity_recall.py | 26 +++ .../scoring_fn/fn_defs/context_precision.py | 26 +++ .../scoring_fn/fn_defs/context_recall.py | 26 +++ .../scoring_fn/fn_defs/context_relevancy.py | 26 +++ .../scoring_fn/fn_defs/factuality.py | 15 +- .../scoring_fn/fn_defs/faithfulness.py | 26 +++ .../inline/scoring/llm_as_judge/scoring.py | 32 ++-- .../scoring_fn/llm_as_judge_scoring_fn.py | 4 +- .../tests/datasetio/test_datasetio.py | 17 +- .../tests/datasetio/test_rag_dataset.csv | 6 + .../providers/tests/scoring/test_scoring.py | 6 +- .../providers/utils/common/__init__.py | 5 + .../utils/common/data_schema_validator.py | 87 ++++++++++ .../utils/scoring/base_scoring_fn.py | 43 ++++- 24 files changed, 544 insertions(+), 139 deletions(-) create mode 100644 llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py create mode 100644 llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_similarity.py create mode 100644 llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_entity_recall.py create mode 100644 llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_precision.py create mode 100644 llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_recall.py create mode 100644 llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_relevancy.py create mode 100644 llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/faithfulness.py create mode 100644 llama_stack/providers/tests/datasetio/test_rag_dataset.csv create mode 100644 llama_stack/providers/utils/common/__init__.py create mode 100644 llama_stack/providers/utils/common/data_schema_validator.py diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py index 453e35f6d..996291dcc 100644 --- a/llama_stack/apis/scoring/scoring.py +++ b/llama_stack/apis/scoring/scoring.py @@ -47,7 +47,7 @@ class Scoring(Protocol): async def score_batch( self, dataset_id: str, - scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, + scoring_functions: Dict[str, Optional[ScoringFnParams]], save_results_dataset: bool = False, ) -> ScoreBatchResponse: ... @@ -55,5 +55,5 @@ class Scoring(Protocol): async def score( self, input_rows: List[Dict[str, Any]], - scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, + scoring_functions: Dict[str, Optional[ScoringFnParams]], ) -> ScoreResponse: ... diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index 00630132e..b555c9f2a 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -3,23 +3,24 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum from typing import Any, Dict, List, Optional from tqdm import tqdm -from llama_stack.apis.agents import Agents -from llama_stack.apis.common.type_system import ( - ChatCompletionInputType, - CompletionInputType, - StringType, -) +from llama_stack.apis.agents import Agents, StepType from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.eval_tasks import EvalTask from llama_stack.apis.inference import Inference, UserMessage from llama_stack.apis.scoring import Scoring +from llama_stack.distribution.datatypes import Api from llama_stack.providers.datatypes import EvalTasksProtocolPrivate + +from llama_stack.providers.utils.common.data_schema_validator import ( + ColumnName, + DataSchemaValidatorMixin, + get_valid_schemas, +) from llama_stack.providers.utils.kvstore import kvstore_impl from .....apis.common.job_types import Job @@ -30,15 +31,7 @@ from .config import MetaReferenceEvalConfig EVAL_TASKS_PREFIX = "eval_tasks:" -class ColumnName(Enum): - input_query = "input_query" - expected_answer = "expected_answer" - chat_completion_input = "chat_completion_input" - completion_input = "completion_input" - generated_answer = "generated_answer" - - -class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): +class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorMixin): def __init__( self, config: MetaReferenceEvalConfig, @@ -82,29 +75,6 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): ) self.eval_tasks[task_def.identifier] = task_def - async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None: - dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) - if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: - raise ValueError(f"Dataset {dataset_id} does not have a schema defined.") - - expected_schemas = [ - { - ColumnName.input_query.value: StringType(), - ColumnName.expected_answer.value: StringType(), - ColumnName.chat_completion_input.value: ChatCompletionInputType(), - }, - { - ColumnName.input_query.value: StringType(), - ColumnName.expected_answer.value: StringType(), - ColumnName.completion_input.value: CompletionInputType(), - }, - ] - - if dataset_def.dataset_schema not in expected_schemas: - raise ValueError( - f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}" - ) - async def run_eval( self, task_id: str, @@ -114,8 +84,10 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): dataset_id = task_def.dataset_id candidate = task_config.eval_candidate scoring_functions = task_def.scoring_functions - - await self.validate_eval_input_dataset_schema(dataset_id=dataset_id) + dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) + self.validate_dataset_schema( + dataset_def.dataset_schema, get_valid_schemas(Api.eval.value) + ) all_rows = await self.datasetio_api.get_rows_paginated( dataset_id=dataset_id, rows_in_page=( @@ -167,11 +139,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): ) ] final_event = turn_response[-1].event.payload - generations.append( - { - ColumnName.generated_answer.value: final_event.turn.output_message.content - } + + # check if there's a memory retrieval step and extract the context + memory_rag_context = None + for step in final_event.turn.steps: + if step.step_type == StepType.memory_retrieval.value: + memory_rag_context = " ".join(x.text for x in step.inserted_context) + + agent_generation = {} + agent_generation[ColumnName.generated_answer.value] = ( + final_event.turn.output_message.content ) + if memory_rag_context: + agent_generation[ColumnName.context.value] = memory_rag_context + + generations.append(agent_generation) return generations diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py index f8b30cbcf..f612abda4 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring.py +++ b/llama_stack/providers/inline/scoring/basic/scoring.py @@ -14,8 +14,13 @@ from llama_stack.apis.scoring import ( ScoringResult, ) from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams -from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate +from llama_stack.distribution.datatypes import Api +from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate +from llama_stack.providers.utils.common.data_schema_validator import ( + DataSchemaValidatorMixin, + get_valid_schemas, +) from .config import BasicScoringConfig from .scoring_fn.equality_scoring_fn import EqualityScoringFn from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn @@ -24,7 +29,9 @@ from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn] -class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): +class BasicScoringImpl( + Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin +): def __init__( self, config: BasicScoringConfig, @@ -61,30 +68,17 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): async def register_scoring_function(self, function_def: ScoringFn) -> None: raise NotImplementedError("Register scoring function not implemented yet") - async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None: - dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) - if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: - raise ValueError( - f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset." - ) - - for required_column in ["generated_answer", "expected_answer", "input_query"]: - if required_column not in dataset_def.dataset_schema: - raise ValueError( - f"Dataset {dataset_id} does not have a '{required_column}' column." - ) - if dataset_def.dataset_schema[required_column].type != "string": - raise ValueError( - f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'." - ) - async def score_batch( self, dataset_id: str, scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, save_results_dataset: bool = False, ) -> ScoreBatchResponse: - await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id) + dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) + self.validate_dataset_schema( + dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value) + ) + all_rows = await self.datasetio_api.get_rows_paginated( dataset_id=dataset_id, rows_in_page=-1, diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py index 9991c5502..9b0566228 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py @@ -9,12 +9,12 @@ from typing import Any, Dict, Optional from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring_functions import ScoringFnParams -from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn +from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn from .fn_defs.equality import equality -class EqualityScoringFn(BaseScoringFn): +class EqualityScoringFn(RegisteredBaseScoringFn): """ A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise. """ diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py index 552f34d46..38014ca6f 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py @@ -9,14 +9,14 @@ from typing import Any, Dict, Optional from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType -from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn +from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn from .fn_defs.regex_parser_multiple_choice_answer import ( regex_parser_multiple_choice_answer, ) -class RegexParserScoringFn(BaseScoringFn): +class RegexParserScoringFn(RegisteredBaseScoringFn): """ A scoring_fn that parses answer from generated response according to context and check match with expected_answer. """ diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py index 29ae12e44..71defc433 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py @@ -8,12 +8,12 @@ from typing import Any, Dict, Optional from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring_functions import ScoringFnParams -from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn +from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn from .fn_defs.subset_of import subset_of -class SubsetOfScoringFn(BaseScoringFn): +class SubsetOfScoringFn(RegisteredBaseScoringFn): """ A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise. """ diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py index 0c6102645..4282ef6ec 100644 --- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py +++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py @@ -7,7 +7,17 @@ import os from typing import Any, Dict, List, Optional from autoevals.llm import Factuality -from autoevals.ragas import AnswerCorrectness +from autoevals.ragas import ( + AnswerCorrectness, + AnswerRelevancy, + AnswerSimilarity, + ContextEntityRecall, + ContextPrecision, + ContextRecall, + ContextRelevancy, + Faithfulness, +) +from pydantic import BaseModel from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets @@ -18,20 +28,90 @@ from llama_stack.apis.scoring import ( ScoringResult, ScoringResultRow, ) -from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFn +from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams + +from llama_stack.distribution.datatypes import Api from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate +from llama_stack.providers.utils.common.data_schema_validator import ( + DataSchemaValidatorMixin, + get_valid_schemas, +) -from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average - +from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics from .config import BraintrustScoringConfig from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def +from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def +from .scoring_fn.fn_defs.answer_similarity import answer_similarity_fn_def +from .scoring_fn.fn_defs.context_entity_recall import context_entity_recall_fn_def +from .scoring_fn.fn_defs.context_precision import context_precision_fn_def +from .scoring_fn.fn_defs.context_recall import context_recall_fn_def +from .scoring_fn.fn_defs.context_relevancy import context_relevancy_fn_def from .scoring_fn.fn_defs.factuality import factuality_fn_def +from .scoring_fn.fn_defs.faithfulness import faithfulness_fn_def + + +class BraintrustScoringFnEntry(BaseModel): + identifier: str + evaluator: Any + fn_def: ScoringFn + + +SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY = [ + BraintrustScoringFnEntry( + identifier="braintrust::factuality", + evaluator=Factuality(), + fn_def=factuality_fn_def, + ), + BraintrustScoringFnEntry( + identifier="braintrust::answer-correctness", + evaluator=AnswerCorrectness(), + fn_def=answer_correctness_fn_def, + ), + BraintrustScoringFnEntry( + identifier="braintrust::answer-relevancy", + evaluator=AnswerRelevancy(), + fn_def=answer_relevancy_fn_def, + ), + BraintrustScoringFnEntry( + identifier="braintrust::answer-similarity", + evaluator=AnswerSimilarity(), + fn_def=answer_similarity_fn_def, + ), + BraintrustScoringFnEntry( + identifier="braintrust::faithfulness", + evaluator=Faithfulness(), + fn_def=faithfulness_fn_def, + ), + BraintrustScoringFnEntry( + identifier="braintrust::context-entity-recall", + evaluator=ContextEntityRecall(), + fn_def=context_entity_recall_fn_def, + ), + BraintrustScoringFnEntry( + identifier="braintrust::context-precision", + evaluator=ContextPrecision(), + fn_def=context_precision_fn_def, + ), + BraintrustScoringFnEntry( + identifier="braintrust::context-recall", + evaluator=ContextRecall(), + fn_def=context_recall_fn_def, + ), + BraintrustScoringFnEntry( + identifier="braintrust::context-relevancy", + evaluator=ContextRelevancy(), + fn_def=context_relevancy_fn_def, + ), +] class BraintrustScoringImpl( - Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData + Scoring, + ScoringFunctionsProtocolPrivate, + NeedsRequestProviderData, + DataSchemaValidatorMixin, ): def __init__( self, @@ -44,12 +124,12 @@ class BraintrustScoringImpl( self.datasets_api = datasets_api self.braintrust_evaluators = { - "braintrust::factuality": Factuality(), - "braintrust::answer-correctness": AnswerCorrectness(), + entry.identifier: entry.evaluator + for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY } self.supported_fn_defs_registry = { - factuality_fn_def.identifier: factuality_fn_def, - answer_correctness_fn_def.identifier: answer_correctness_fn_def, + entry.identifier: entry.fn_def + for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY } async def initialize(self) -> None: ... @@ -70,23 +150,6 @@ class BraintrustScoringImpl( "Registering scoring function not allowed for braintrust provider" ) - async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None: - dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) - if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: - raise ValueError( - f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset." - ) - - for required_column in ["generated_answer", "expected_answer", "input_query"]: - if required_column not in dataset_def.dataset_schema: - raise ValueError( - f"Dataset {dataset_id} does not have a '{required_column}' column." - ) - if dataset_def.dataset_schema[required_column].type != "string": - raise ValueError( - f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'." - ) - async def set_api_key(self) -> None: # api key is in the request headers if not self.config.openai_api_key: @@ -102,11 +165,16 @@ class BraintrustScoringImpl( async def score_batch( self, dataset_id: str, - scoring_functions: List[str], + scoring_functions: Dict[str, Optional[ScoringFnParams]], save_results_dataset: bool = False, ) -> ScoreBatchResponse: await self.set_api_key() - await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id) + + dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) + self.validate_dataset_schema( + dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value) + ) + all_rows = await self.datasetio_api.get_rows_paginated( dataset_id=dataset_id, rows_in_page=-1, @@ -126,6 +194,7 @@ class BraintrustScoringImpl( async def score_row( self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None ) -> ScoringResultRow: + self.validate_row_schema(input_row, get_valid_schemas(Api.scoring.value)) await self.set_api_key() assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None" expected_answer = input_row["expected_answer"] @@ -133,12 +202,19 @@ class BraintrustScoringImpl( input_query = input_row["input_query"] evaluator = self.braintrust_evaluators[scoring_fn_identifier] - result = evaluator(generated_answer, expected_answer, input=input_query) + result = evaluator( + generated_answer, + expected_answer, + input=input_query, + context=input_row["context"] if "context" in input_row else None, + ) score = result.score return {"score": score, "metadata": result.metadata} async def score( - self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] + self, + input_rows: List[Dict[str, Any]], + scoring_functions: Dict[str, Optional[ScoringFnParams]], ) -> ScoreResponse: await self.set_api_key() res = {} @@ -150,8 +226,17 @@ class BraintrustScoringImpl( await self.score_row(input_row, scoring_fn_id) for input_row in input_rows ] - aggregation_functions = [AggregationFunctionType.average] - agg_results = aggregate_average(score_results) + aggregation_functions = self.supported_fn_defs_registry[ + scoring_fn_id + ].params.aggregation_functions + + # override scoring_fn params if provided + if scoring_functions[scoring_fn_id] is not None: + override_params = scoring_functions[scoring_fn_id] + if override_params.aggregation_functions: + aggregation_functions = override_params.aggregation_functions + + agg_results = aggregate_metrics(score_results, aggregation_functions) res[scoring_fn_id] = ScoringResult( score_rows=score_results, aggregated_results=agg_results, diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py index dc5df8e78..526ba2c37 100644 --- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py +++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py @@ -5,14 +5,23 @@ # the root directory of this source tree. from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ScoringFn +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + BasicScoringFnParams, + ScoringFn, +) answer_correctness_fn_def = ScoringFn( identifier="braintrust::answer-correctness", - description="Scores the correctness of the answer based on the ground truth.. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py", - params=None, + description=( + "Scores the correctness of the answer based on the ground truth. " + "Uses Braintrust LLM-based scorer from autoevals library." + ), provider_id="braintrust", provider_resource_id="answer-correctness", return_type=NumberType(), + params=BasicScoringFnParams( + aggregation_functions=[AggregationFunctionType.average] + ), ) diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py new file mode 100644 index 000000000..3e3e6ac87 --- /dev/null +++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.common.type_system import NumberType +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + BasicScoringFnParams, + ScoringFn, +) + +answer_relevancy_fn_def = ScoringFn( + identifier="braintrust::answer-relevancy", + description=( + "Test output relevancy against the input query using Braintrust LLM scorer. " + "See: github.com/braintrustdata/autoevals" + ), + provider_id="braintrust", + provider_resource_id="answer-relevancy", + return_type=NumberType(), + params=BasicScoringFnParams( + aggregation_functions=[AggregationFunctionType.average] + ), +) diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_similarity.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_similarity.py new file mode 100644 index 000000000..bea8dfd53 --- /dev/null +++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_similarity.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.common.type_system import NumberType +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + BasicScoringFnParams, + ScoringFn, +) + +answer_similarity_fn_def = ScoringFn( + identifier="braintrust::answer-similarity", + description=( + "Test output similarity against expected value using Braintrust LLM scorer. " + "See: github.com/braintrustdata/autoevals" + ), + provider_id="braintrust", + provider_resource_id="answer-similarity", + return_type=NumberType(), + params=BasicScoringFnParams( + aggregation_functions=[AggregationFunctionType.average] + ), +) diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_entity_recall.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_entity_recall.py new file mode 100644 index 000000000..ac41df000 --- /dev/null +++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_entity_recall.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.common.type_system import NumberType +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + BasicScoringFnParams, + ScoringFn, +) + +context_entity_recall_fn_def = ScoringFn( + identifier="braintrust::context-entity-recall", + description=( + "Evaluates how well the context captures the named entities present in the " + "reference answer. See: github.com/braintrustdata/autoevals" + ), + provider_id="braintrust", + provider_resource_id="context-entity-recall", + return_type=NumberType(), + params=BasicScoringFnParams( + aggregation_functions=[AggregationFunctionType.average] + ), +) diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_precision.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_precision.py new file mode 100644 index 000000000..ef172d82c --- /dev/null +++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_precision.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.common.type_system import NumberType +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + BasicScoringFnParams, + ScoringFn, +) + +context_precision_fn_def = ScoringFn( + identifier="braintrust::context-precision", + description=( + "Measures how much of the provided context is actually relevant to answering the " + "question. See: github.com/braintrustdata/autoevals" + ), + provider_id="braintrust", + provider_resource_id="context-precision", + return_type=NumberType(), + params=BasicScoringFnParams( + aggregation_functions=[AggregationFunctionType.average] + ), +) diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_recall.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_recall.py new file mode 100644 index 000000000..d4561a5d4 --- /dev/null +++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_recall.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.common.type_system import NumberType +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + BasicScoringFnParams, + ScoringFn, +) + +context_recall_fn_def = ScoringFn( + identifier="braintrust::context-recall", + description=( + "Evaluates how well the context covers the information needed to answer the " + "question. See: github.com/braintrustdata/autoevals" + ), + provider_id="braintrust", + provider_resource_id="context-recall", + return_type=NumberType(), + params=BasicScoringFnParams( + aggregation_functions=[AggregationFunctionType.average] + ), +) diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_relevancy.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_relevancy.py new file mode 100644 index 000000000..06fc86a7b --- /dev/null +++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_relevancy.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.common.type_system import NumberType +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + BasicScoringFnParams, + ScoringFn, +) + +context_relevancy_fn_def = ScoringFn( + identifier="braintrust::context-relevancy", + description=( + "Assesses how relevant the provided context is to the given question. " + "See: github.com/braintrustdata/autoevals" + ), + provider_id="braintrust", + provider_resource_id="context-relevancy", + return_type=NumberType(), + params=BasicScoringFnParams( + aggregation_functions=[AggregationFunctionType.average] + ), +) diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py index b733f10c8..a4d597c29 100644 --- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py +++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py @@ -5,14 +5,23 @@ # the root directory of this source tree. from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ScoringFn +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + BasicScoringFnParams, + ScoringFn, +) factuality_fn_def = ScoringFn( identifier="braintrust::factuality", - description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py", - params=None, + description=( + "Test output factuality against expected value using Braintrust LLM scorer. " + "See: github.com/braintrustdata/autoevals" + ), provider_id="braintrust", provider_resource_id="factuality", return_type=NumberType(), + params=BasicScoringFnParams( + aggregation_functions=[AggregationFunctionType.average] + ), ) diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/faithfulness.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/faithfulness.py new file mode 100644 index 000000000..9cffff558 --- /dev/null +++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/faithfulness.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.common.type_system import NumberType +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + BasicScoringFnParams, + ScoringFn, +) + +faithfulness_fn_def = ScoringFn( + identifier="braintrust::faithfulness", + description=( + "Test output faithfulness to the input query using Braintrust LLM scorer. " + "See: github.com/braintrustdata/autoevals" + ), + provider_id="braintrust", + provider_resource_id="faithfulness", + return_type=NumberType(), + params=BasicScoringFnParams( + aggregation_functions=[AggregationFunctionType.average] + ), +) diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py index 09780e6fb..305c13665 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py @@ -16,7 +16,12 @@ from llama_stack.apis.scoring import ( ScoringResult, ) from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams +from llama_stack.distribution.datatypes import Api from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate +from llama_stack.providers.utils.common.data_schema_validator import ( + DataSchemaValidatorMixin, + get_valid_schemas, +) from .config import LlmAsJudgeScoringConfig from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn @@ -25,7 +30,9 @@ from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn LLM_JUDGE_FNS = [LlmAsJudgeScoringFn] -class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): +class LlmAsJudgeScoringImpl( + Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin +): def __init__( self, config: LlmAsJudgeScoringConfig, @@ -65,30 +72,17 @@ class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): async def register_scoring_function(self, function_def: ScoringFn) -> None: raise NotImplementedError("Register scoring function not implemented yet") - async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None: - dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) - if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: - raise ValueError( - f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset." - ) - - for required_column in ["generated_answer", "expected_answer", "input_query"]: - if required_column not in dataset_def.dataset_schema: - raise ValueError( - f"Dataset {dataset_id} does not have a '{required_column}' column." - ) - if dataset_def.dataset_schema[required_column].type != "string": - raise ValueError( - f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'." - ) - async def score_batch( self, dataset_id: str, scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, save_results_dataset: bool = False, ) -> ScoreBatchResponse: - await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id) + dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) + self.validate_dataset_schema( + dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value) + ) + all_rows = await self.datasetio_api.get_rows_paginated( dataset_id=dataset_id, rows_in_page=-1, diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py index 00ea53c8f..027709f74 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py @@ -12,14 +12,14 @@ from llama_stack.apis.inference.inference import Inference from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring_functions import ScoringFnParams -from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn +from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa from .fn_defs.llm_as_judge_base import llm_as_judge_base -class LlmAsJudgeScoringFn(BaseScoringFn): +class LlmAsJudgeScoringFn(RegisteredBaseScoringFn): """ A scoring_fn that assigns """ diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index 46c99f5b3..cf28045a4 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -38,9 +38,15 @@ def data_url_from_file(file_path: str) -> str: async def register_dataset( - datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset" + datasets_impl: Datasets, + for_generation=False, + for_rag=False, + dataset_id="test_dataset", ): - test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv" + if for_rag: + test_file = Path(os.path.abspath(__file__)).parent / "test_rag_dataset.csv" + else: + test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv" test_url = data_url_from_file(str(test_file)) if for_generation: @@ -49,6 +55,13 @@ async def register_dataset( "input_query": StringType(), "chat_completion_input": ChatCompletionInputType(), } + elif for_rag: + dataset_schema = { + "expected_answer": StringType(), + "input_query": StringType(), + "generated_answer": StringType(), + "context": StringType(), + } else: dataset_schema = { "expected_answer": StringType(), diff --git a/llama_stack/providers/tests/datasetio/test_rag_dataset.csv b/llama_stack/providers/tests/datasetio/test_rag_dataset.csv new file mode 100644 index 000000000..a0e1fce72 --- /dev/null +++ b/llama_stack/providers/tests/datasetio/test_rag_dataset.csv @@ -0,0 +1,6 @@ +input_query,context,generated_answer,expected_answer +What is the capital of France?,"France is a country in Western Europe with a population of about 67 million people. Its capital city has been a major European cultural center since the 17th century and is known for landmarks like the Eiffel Tower and the Louvre Museum.",London,Paris +Who is the CEO of Meta?,"Meta Platforms, formerly known as Facebook, is one of the world's largest technology companies. Founded by Mark Zuckerberg in 2004, the company has expanded to include platforms like Instagram, WhatsApp, and virtual reality technologies.",Mark Zuckerberg,Mark Zuckerberg +What is the largest planet in our solar system?,"The solar system consists of eight planets orbiting around the Sun. These planets, in order from the Sun, are Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, and Neptune. Gas giants are significantly larger than terrestrial planets.",Jupiter,Jupiter +What is the smallest country in the world?,"Independent city-states and micronations are among the world's smallest sovereign territories. Some notable examples include Monaco, San Marino, and Vatican City, which is an enclave within Rome, Italy.",China,Vatican City +What is the currency of Japan?,"Japan is an island country in East Asia with a rich cultural heritage and one of the world's largest economies. Its financial system has been established since the Meiji period, with its modern currency being introduced in 1871.",Yen,Yen diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py index 2643b8fd6..00dd5d27b 100644 --- a/llama_stack/providers/tests/scoring/test_scoring.py +++ b/llama_stack/providers/tests/scoring/test_scoring.py @@ -60,7 +60,7 @@ class TestScoring: f"{provider_id} provider does not support scoring without params" ) - await register_dataset(datasets_impl) + await register_dataset(datasets_impl, for_rag=True) response = await datasets_impl.list_datasets() assert len(response) == 1 @@ -112,7 +112,7 @@ class TestScoring: scoring_stack[Api.datasets], scoring_stack[Api.models], ) - await register_dataset(datasets_impl) + await register_dataset(datasets_impl, for_rag=True) response = await datasets_impl.list_datasets() assert len(response) == 1 @@ -173,7 +173,7 @@ class TestScoring: scoring_stack[Api.datasets], scoring_stack[Api.models], ) - await register_dataset(datasets_impl) + await register_dataset(datasets_impl, for_rag=True) rows = await datasetio_impl.get_rows_paginated( dataset_id="test_dataset", rows_in_page=3, diff --git a/llama_stack/providers/utils/common/__init__.py b/llama_stack/providers/utils/common/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/utils/common/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/utils/common/data_schema_validator.py b/llama_stack/providers/utils/common/data_schema_validator.py new file mode 100644 index 000000000..d9e6cb6b5 --- /dev/null +++ b/llama_stack/providers/utils/common/data_schema_validator.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import Enum +from typing import Any, Dict, List + +from llama_stack.apis.common.type_system import ( + ChatCompletionInputType, + CompletionInputType, + StringType, +) + +from llama_stack.distribution.datatypes import Api + + +class ColumnName(Enum): + input_query = "input_query" + expected_answer = "expected_answer" + chat_completion_input = "chat_completion_input" + completion_input = "completion_input" + generated_answer = "generated_answer" + context = "context" + + +VALID_SCHEMAS_FOR_SCORING = [ + { + ColumnName.input_query.value: StringType(), + ColumnName.expected_answer.value: StringType(), + ColumnName.generated_answer.value: StringType(), + }, + { + ColumnName.input_query.value: StringType(), + ColumnName.expected_answer.value: StringType(), + ColumnName.generated_answer.value: StringType(), + ColumnName.context.value: StringType(), + }, +] + +VALID_SCHEMAS_FOR_EVAL = [ + { + ColumnName.input_query.value: StringType(), + ColumnName.expected_answer.value: StringType(), + ColumnName.chat_completion_input.value: ChatCompletionInputType(), + }, + { + ColumnName.input_query.value: StringType(), + ColumnName.expected_answer.value: StringType(), + ColumnName.completion_input.value: CompletionInputType(), + }, +] + + +def get_valid_schemas(api_str: str): + if api_str == Api.scoring.value: + return VALID_SCHEMAS_FOR_SCORING + elif api_str == Api.eval.value: + return VALID_SCHEMAS_FOR_EVAL + else: + raise ValueError(f"Invalid API string: {api_str}") + + +class DataSchemaValidatorMixin: + def validate_dataset_schema( + self, + dataset_schema: Dict[str, Any], + expected_schemas: List[Dict[str, Any]], + ): + if dataset_schema not in expected_schemas: + raise ValueError( + f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}" + ) + + def validate_row_schema( + self, + input_row: Dict[str, Any], + expected_schemas: List[Dict[str, Any]], + ): + for schema in expected_schemas: + if all(key in input_row for key in schema): + return + + raise ValueError( + f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}" + ) diff --git a/llama_stack/providers/utils/scoring/base_scoring_fn.py b/llama_stack/providers/utils/scoring/base_scoring_fn.py index 2db77fd2b..e0e557374 100644 --- a/llama_stack/providers/utils/scoring/base_scoring_fn.py +++ b/llama_stack/providers/utils/scoring/base_scoring_fn.py @@ -13,12 +13,51 @@ from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metr class BaseScoringFn(ABC): """ - Base interface class for all native scoring_fns. - Each scoring_fn needs to implement the following methods: + Base interface class for Scoring Functions. + Each scoring function needs to implement the following methods: - score_row(self, row) - aggregate(self, scoring_fn_results) """ + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def __str__(self) -> str: + return self.__class__.__name__ + + @abstractmethod + async def score_row( + self, + input_row: Dict[str, Any], + scoring_fn_identifier: Optional[str] = None, + scoring_params: Optional[ScoringFnParams] = None, + ) -> ScoringResultRow: + raise NotImplementedError() + + @abstractmethod + async def aggregate( + self, + scoring_results: List[ScoringResultRow], + scoring_fn_identifier: Optional[str] = None, + scoring_params: Optional[ScoringFnParams] = None, + ) -> Dict[str, Any]: + raise NotImplementedError() + + @abstractmethod + async def score( + self, + input_rows: List[Dict[str, Any]], + scoring_fn_identifier: Optional[str] = None, + scoring_params: Optional[ScoringFnParams] = None, + ) -> List[ScoringResultRow]: + raise NotImplementedError() + + +class RegisteredBaseScoringFn(BaseScoringFn): + """ + Interface for native scoring functions that are registered in LlamaStack. + """ + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.supported_fn_defs_registry = {} From b438e616ffca53bdea8c3a171932c25c35447795 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 2 Jan 2025 11:26:19 -0800 Subject: [PATCH 09/12] kill api key from notebook --- docs/notebooks/Llama_Stack_Building_AI_Applications.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/notebooks/Llama_Stack_Building_AI_Applications.ipynb b/docs/notebooks/Llama_Stack_Building_AI_Applications.ipynb index fa527f1a0..d061603c8 100644 --- a/docs/notebooks/Llama_Stack_Building_AI_Applications.ipynb +++ b/docs/notebooks/Llama_Stack_Building_AI_Applications.ipynb @@ -544,7 +544,7 @@ " provider_type: inline::meta-reference\n", " inference:\n", " - config:\n", - " api_key: 4985b03e627419b2964d34b8519ac6c4319f094d1ffb4f45514b4eb87e5427a2\n", + " api_key: <...>\n", " url: https://api.together.xyz/v1\n", " provider_id: together\n", " provider_type: remote::together\n", @@ -663,7 +663,7 @@ " provider_type: inline::meta-reference\n", " inference:\n", " - config:\n", - " api_key: 4985b03e627419b2964d34b8519ac6c4319f094d1ffb4f45514b4eb87e5427a2\n", + " api_key: <...>\n", " url: \u001b[4;94mhttps://api.together.xyz/v1\u001b[0m\n", " provider_id: together\n", " provider_type: remote::together\n", From 750604c7af8d983ed8e6d94b6d129efb6ffdcedc Mon Sep 17 00:00:00 2001 From: Botao Chen Date: Thu, 2 Jan 2025 13:08:20 -0800 Subject: [PATCH 10/12] [Post Training] Fix missing import (#705) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## context Post training apis are broken after the import * refactor https://github.com/meta-llama/llama-stack/pull/689. This PR is adding the missing import back ## Test Issue a post training request from client and the training finishes successfully Screenshot 2025-01-02 at 12 18 45 PM Screenshot 2025-01-02 at 12 18 52 PM --- .../providers/inline/post_training/torchtune/common/utils.py | 2 ++ .../torchtune/recipes/lora_finetuning_single_device.py | 1 + 2 files changed, 3 insertions(+) diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py index f2a2edae5..9673e0732 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py @@ -15,6 +15,8 @@ from typing import Any, Callable, Dict, List import torch from llama_models.datatypes import Model + +from llama_models.llama3.api.datatypes import BaseModel from llama_models.sku_list import resolve_model from llama_stack.apis.common.type_system import ParamType, StringType from llama_stack.apis.datasets import Datasets diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 517be6d89..1b6c508a7 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -7,6 +7,7 @@ import logging import os import time +from datetime import datetime from functools import partial from pathlib import Path from typing import Any, Dict, List, Optional, Tuple From d9f75cc98fbb4172751c97e191ec8df819c92b2a Mon Sep 17 00:00:00 2001 From: Botao Chen Date: Thu, 2 Jan 2025 13:15:31 -0800 Subject: [PATCH 11/12] Import from the right path (#708) Import BaseModel and Field from pydantic --- llama_stack/apis/eval/eval.py | 3 ++- .../providers/inline/post_training/torchtune/common/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index 2592bca37..1073d6310 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -6,9 +6,10 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, Union -from llama_models.llama3.api.datatypes import BaseModel, Field from llama_models.schema_utils import json_schema_type, webmethod +from pydantic import BaseModel, Field + from typing_extensions import Annotated from llama_stack.apis.agents import AgentConfig diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py index 9673e0732..a5279cdbe 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py @@ -15,12 +15,12 @@ from typing import Any, Callable, Dict, List import torch from llama_models.datatypes import Model - -from llama_models.llama3.api.datatypes import BaseModel from llama_models.sku_list import resolve_model from llama_stack.apis.common.type_system import ParamType, StringType from llama_stack.apis.datasets import Datasets +from pydantic import BaseModel + from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b from torchtune.models.llama3._tokenizer import Llama3Tokenizer from torchtune.models.llama3_2 import lora_llama3_2_3b From e3f187fb83f2c45d5f838663658a873fb0fcc6d9 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 2 Jan 2025 11:40:48 -0800 Subject: [PATCH 12/12] Redact sensitive information from configs when printing, etc. --- llama_stack/distribution/library_client.py | 6 +++++- llama_stack/distribution/server/server.py | 4 +++- llama_stack/distribution/stack.py | 20 +++++++++++++++++++ .../remote/inference/cerebras/cerebras.py | 3 ++- .../remote/inference/cerebras/config.py | 4 ++-- .../remote/inference/fireworks/config.py | 4 ++-- .../remote/inference/fireworks/fireworks.py | 2 +- .../remote/inference/nvidia/config.py | 4 ++-- .../remote/inference/nvidia/nvidia.py | 6 +++++- .../providers/remote/inference/tgi/config.py | 8 ++++---- .../providers/remote/inference/tgi/tgi.py | 8 +++++--- .../remote/inference/together/config.py | 4 ++-- .../remote/inference/together/together.py | 2 +- 13 files changed, 54 insertions(+), 21 deletions(-) diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 48fcc437b..01b8bb3b5 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -39,6 +39,7 @@ from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.distribution.stack import ( construct_stack, get_stack_run_config_from_template, + redact_sensitive_fields, replace_env_vars, ) @@ -273,7 +274,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): console = Console() console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:") - console.print(yaml.dump(self.config.model_dump(), indent=2)) + + # Redact sensitive information before printing + safe_config = redact_sensitive_fields(self.config.model_dump()) + console.print(yaml.dump(safe_config, indent=2)) endpoints = get_all_api_endpoints() endpoint_impls = {} diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index daaf8475b..e432cca4e 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -35,6 +35,7 @@ from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.stack import ( construct_stack, + redact_sensitive_fields, replace_env_vars, validate_env_pair, ) @@ -280,7 +281,8 @@ def main(): config = StackRunConfig(**config) print("Run configuration:") - print(yaml.dump(config.model_dump(), indent=2)) + safe_config = redact_sensitive_fields(config.model_dump()) + print(yaml.dump(safe_config, indent=2)) app = FastAPI(lifespan=lifespan) app.add_middleware(TracingMiddleware) diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 965df5f03..7fc2c7650 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -112,6 +112,26 @@ class EnvVarError(Exception): ) +def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]: + """Redact sensitive information from config before printing.""" + sensitive_patterns = ["api_key", "api_token", "password", "secret"] + + def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]: + result = {} + for k, v in d.items(): + if isinstance(v, dict): + result[k] = _redact_dict(v) + elif isinstance(v, list): + result[k] = [_redact_dict(i) if isinstance(i, dict) else i for i in v] + elif any(pattern in k.lower() for pattern in sensitive_patterns): + result[k] = "********" + else: + result[k] = v + return result + + return _redact_dict(data) + + def replace_env_vars(config: Any, path: str = "") -> Any: if isinstance(config, dict): result = {} diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 40457e1ae..586447012 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -71,7 +71,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference): self.formatter = ChatFormat(Tokenizer.get_instance()) self.client = AsyncCerebras( - base_url=self.config.base_url, api_key=self.config.api_key + base_url=self.config.base_url, + api_key=self.config.api_key.get_secret_value(), ) async def initialize(self) -> None: diff --git a/llama_stack/providers/remote/inference/cerebras/config.py b/llama_stack/providers/remote/inference/cerebras/config.py index 9bae6ca4d..6eb4dffec 100644 --- a/llama_stack/providers/remote/inference/cerebras/config.py +++ b/llama_stack/providers/remote/inference/cerebras/config.py @@ -8,7 +8,7 @@ import os from typing import Any, Dict, Optional from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, SecretStr DEFAULT_BASE_URL = "https://api.cerebras.ai" @@ -19,7 +19,7 @@ class CerebrasImplConfig(BaseModel): default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL), description="Base URL for the Cerebras API", ) - api_key: Optional[str] = Field( + api_key: Optional[SecretStr] = Field( default=os.environ.get("CEREBRAS_API_KEY"), description="Cerebras API Key", ) diff --git a/llama_stack/providers/remote/inference/fireworks/config.py b/llama_stack/providers/remote/inference/fireworks/config.py index 979e8455a..d84a00d56 100644 --- a/llama_stack/providers/remote/inference/fireworks/config.py +++ b/llama_stack/providers/remote/inference/fireworks/config.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, SecretStr @json_schema_type @@ -16,7 +16,7 @@ class FireworksImplConfig(BaseModel): default="https://api.fireworks.ai/inference/v1", description="The URL for the Fireworks server", ) - api_key: Optional[str] = Field( + api_key: Optional[SecretStr] = Field( default=None, description="The Fireworks.ai API Key", ) diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 7a00194ac..6706e9f4a 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -113,7 +113,7 @@ class FireworksInferenceAdapter( def _get_api_key(self) -> str: if self.config.api_key is not None: - return self.config.api_key + return self.config.api_key.get_secret_value() else: provider_data = self.get_request_provider_data() if provider_data is None or not provider_data.fireworks_api_key: diff --git a/llama_stack/providers/remote/inference/nvidia/config.py b/llama_stack/providers/remote/inference/nvidia/config.py index 28be43f4c..9e81211bd 100644 --- a/llama_stack/providers/remote/inference/nvidia/config.py +++ b/llama_stack/providers/remote/inference/nvidia/config.py @@ -8,7 +8,7 @@ import os from typing import Optional from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, SecretStr @json_schema_type @@ -40,7 +40,7 @@ class NVIDIAConfig(BaseModel): ), description="A base url for accessing the NVIDIA NIM", ) - api_key: Optional[str] = Field( + api_key: Optional[SecretStr] = Field( default_factory=lambda: os.getenv("NVIDIA_API_KEY"), description="The NVIDIA API key, only needed of using the hosted service", ) diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 585ad83c7..42c4db53e 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -113,7 +113,11 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): # make sure the client lives longer than any async calls self._client = AsyncOpenAI( base_url=f"{self._config.url}/v1", - api_key=self._config.api_key or "NO KEY", + api_key=( + self._config.api_key.get_secret_value() + if self._config.api_key + else "NO KEY" + ), timeout=self._config.timeout, ) diff --git a/llama_stack/providers/remote/inference/tgi/config.py b/llama_stack/providers/remote/inference/tgi/config.py index 230eaacab..f05005b25 100644 --- a/llama_stack/providers/remote/inference/tgi/config.py +++ b/llama_stack/providers/remote/inference/tgi/config.py @@ -7,7 +7,7 @@ from typing import Optional from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, SecretStr @json_schema_type @@ -15,7 +15,7 @@ class TGIImplConfig(BaseModel): url: str = Field( description="The URL for the TGI serving endpoint", ) - api_token: Optional[str] = Field( + api_token: Optional[SecretStr] = Field( default=None, description="A bearer token if your TGI endpoint is protected.", ) @@ -32,7 +32,7 @@ class InferenceEndpointImplConfig(BaseModel): endpoint_name: str = Field( description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.", ) - api_token: Optional[str] = Field( + api_token: Optional[SecretStr] = Field( default=None, description="Your Hugging Face user access token (will default to locally saved token if not provided)", ) @@ -55,7 +55,7 @@ class InferenceAPIImplConfig(BaseModel): huggingface_repo: str = Field( description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')", ) - api_token: Optional[str] = Field( + api_token: Optional[SecretStr] = Field( default=None, description="Your Hugging Face user access token (will default to locally saved token if not provided)", ) diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index dd02c055a..25d2e0cb8 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -290,7 +290,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): class TGIAdapter(_HfAdapter): async def initialize(self, config: TGIImplConfig) -> None: log.info(f"Initializing TGI client with url={config.url}") - self.client = AsyncInferenceClient(model=config.url, token=config.api_token) + self.client = AsyncInferenceClient( + model=config.url, token=config.api_token.get_secret_value() + ) endpoint_info = await self.client.get_endpoint_info() self.max_tokens = endpoint_info["max_total_tokens"] self.model_id = endpoint_info["model_id"] @@ -299,7 +301,7 @@ class TGIAdapter(_HfAdapter): class InferenceAPIAdapter(_HfAdapter): async def initialize(self, config: InferenceAPIImplConfig) -> None: self.client = AsyncInferenceClient( - model=config.huggingface_repo, token=config.api_token + model=config.huggingface_repo, token=config.api_token.get_secret_value() ) endpoint_info = await self.client.get_endpoint_info() self.max_tokens = endpoint_info["max_total_tokens"] @@ -309,7 +311,7 @@ class InferenceAPIAdapter(_HfAdapter): class InferenceEndpointAdapter(_HfAdapter): async def initialize(self, config: InferenceEndpointImplConfig) -> None: # Get the inference endpoint details - api = HfApi(token=config.api_token) + api = HfApi(token=config.api_token.get_secret_value()) endpoint = api.get_inference_endpoint(config.endpoint_name) # Wait for the endpoint to be ready (if not already) diff --git a/llama_stack/providers/remote/inference/together/config.py b/llama_stack/providers/remote/inference/together/config.py index ecbe9ec06..a56cb5bb8 100644 --- a/llama_stack/providers/remote/inference/together/config.py +++ b/llama_stack/providers/remote/inference/together/config.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, SecretStr @json_schema_type @@ -16,7 +16,7 @@ class TogetherImplConfig(BaseModel): default="https://api.together.xyz/v1", description="The URL for the Together AI server", ) - api_key: Optional[str] = Field( + api_key: Optional[SecretStr] = Field( default=None, description="The Together AI API Key", ) diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 6b5a6a3b0..f8e889ab3 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -130,7 +130,7 @@ class TogetherInferenceAdapter( def _get_client(self) -> Together: together_api_key = None if self.config.api_key is not None: - together_api_key = self.config.api_key + together_api_key = self.config.api_key.get_secret_value() else: provider_data = self.get_request_provider_data() if provider_data is None or not provider_data.together_api_key: