From 2370e826bce58495767dbbe8484c17ae451f71d7 Mon Sep 17 00:00:00 2001 From: LESSuseLESS Date: Tue, 11 Mar 2025 14:41:55 -0700 Subject: [PATCH 1/7] test: adding an e2e test for measuring TTFT (#1568) # What does this PR do? TTFT number largely depends on input length. Ideally we have a "standard" test that we can use to measure against any llama stack serving. TODO: Once JSON is replaced with YAML, I will add "notes" for each test to explain purpose of each test in place. ## Test plan Please refer to e2e test doc for setup. ``` LLAMA_STACK_PORT=8322 pytest -v -s --stack-config="http://localhost:8322" \ --text-model="meta-llama/Llama-3.2-3B-Instruct" \ tests/integration/inference/test_text_inference.py::test_text_chat_completion_first_token_profiling ``` --- .../inference/test_text_inference.py | 45 +++++++++++++++++++ .../test_cases/inference/chat_completion.json | 12 +++++ 2 files changed, 57 insertions(+) diff --git a/tests/integration/inference/test_text_inference.py b/tests/integration/inference/test_text_inference.py index 7e3e14dbc..c9649df60 100644 --- a/tests/integration/inference/test_text_inference.py +++ b/tests/integration/inference/test_text_inference.py @@ -5,6 +5,8 @@ # the root directory of this source tree. +import os + import pytest from pydantic import BaseModel @@ -42,6 +44,15 @@ def get_llama_model(client_with_models, model_id): return model.metadata.get("llama_model", None) +def get_llama_tokenizer(): + from llama_models.llama3.api.chat_format import ChatFormat + from llama_models.llama3.api.tokenizer import Tokenizer + + tokenizer = Tokenizer.get_instance() + formatter = ChatFormat(tokenizer) + return tokenizer, formatter + + @pytest.mark.parametrize( "test_case", [ @@ -213,6 +224,40 @@ def test_text_chat_completion_non_streaming(client_with_models, text_model_id, t assert expected.lower() in message_content +@pytest.mark.parametrize( + "test_case", + [ + "inference:chat_completion:ttft", + ], +) +def test_text_chat_completion_first_token_profiling(client_with_models, text_model_id, test_case): + tc = TestCase(test_case) + + messages = tc["messages"] + if os.environ.get("DEBUG_TTFT"): # debugging print number of tokens in input, ideally around 800 + from pydantic import TypeAdapter + + from llama_stack.apis.inference import Message + + tokenizer, formatter = get_llama_tokenizer() + typed_messages = [TypeAdapter(Message).validate_python(m) for m in messages] + encoded = formatter.encode_dialog_prompt(typed_messages, None) + raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0) + + response = client_with_models.inference.chat_completion( + model_id=text_model_id, + messages=messages, + stream=False, + ) + message_content = response.completion_message.content.lower().strip() + assert len(message_content) > 0 + + if os.environ.get("DEBUG_TTFT"): # debugging print number of tokens in response, ideally around 150 + tokenizer, formatter = get_llama_tokenizer() + encoded = formatter.encode_content(message_content) + raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0) + + @pytest.mark.parametrize( "test_case", [ diff --git a/tests/integration/test_cases/inference/chat_completion.json b/tests/integration/test_cases/inference/chat_completion.json index b804632b7..e87c046b0 100644 --- a/tests/integration/test_cases/inference/chat_completion.json +++ b/tests/integration/test_cases/inference/chat_completion.json @@ -11,6 +11,18 @@ "expected": "Saturn" } }, + "ttft": { + "data": { + "messages": [ + {"role": "system", "content": "You are a 
helpful assistant."}, + {"role": "user", "content": "Can you write me a novel?"}, + {"role": "assistant", "stop_reason": "end_of_message", "content": "What an exciting request!\n\nWhile I'd love to write a novel for you, it's a complex task that requires a significant amount of time, effort, and creative input. A novel typically has:\n\n1. A cohesive plot with multiple characters, subplots, and themes.\n2. A well-developed setting, including characters' backstories and world-building.\n3. A narrative structure, including pacing, tension, and conflict.\n4. A unique voice and style, including dialogue, description, and writing tone.\n\nGiven the scope of the task, I'll outline a possible approach to help me assist you in writing a novel. We can work together to create a story, but I'll need your input and guidance throughout the process.\n\nHere's a suggested outline:\n\n1. **Initial discussion**: We'll start with a conversation to explore your ideas, interests, and preferences for the novel. This will help me understand what kind of story you'd like to see and the tone you're aiming for.\n2. **Genre and setting**: We'll determine the genre (e.g., fantasy, sci-fi, romance, mystery, etc.) and setting (e.g., historical, futuristic, modern, etc.) for the novel.\n3. **Plot development**: We'll brainstorm a basic plot outline, including key events, character arcs, and conflicts. I'll provide suggestions and ideas to help flesh out the story.\n4. **Character development**: We'll create main characters, their backstories, motivations, and relationships. I'll help you develop their personalities, dialogue, and behaviors.\n5. **Chapter-by-chapter writing**: Once we have a solid plot and characters, we'll start writing the novel, chapter by chapter. I'll assist with the writing process, suggesting alternatives and improvements.\n6. **Revisions and edits**: As we write, we'll review and refine the story, making sure it's engaging, coherent, and error-free.\n\nKeep in mind that writing a novel is a collaborative effort. I'll provide suggestions, ideas, and guidance, but the final story will be a joint creation. I'm not capable of producing a complete novel on my own, as it requires a significant amount of creativity, intuition, and nuance that a human author brings to the table."}, + {"role": "user", "content": "But do you have suggestions on a plot?"}, + {"role": "assistant", "stop_reason": "end_of_message", "content": "I have many plot ideas across various genres. Here are a few to get you started:\n\n**Fantasy:**\n\n1. **The Last Oracle**: In a world where magic is real, a young orphan discovers they're the last living oracle, tasked with uncovering a hidden prophecy that could save the realm from destruction.\n2. **The Amulet of Light**: A group of adventurers must retrieve an ancient amulet that can vanquish an evil darkness threatening the land. Along the way, they uncover a hidden history and unexpected allies.\n3. **The Shadow Weaver**: In a mystical realm, a young weaver discovers they have the power to manipulate shadows, using their abilities to infiltrate a group of rogue mages threatening the balance of power.\n\n**Science Fiction:**\n\n1. **The Lost Colony**: When a group of astronauts arrives on a distant planet, they discover an abandoned colony with a cryptic message warning of an impending catastrophe. As they unravel the mystery, they must confront the consequences of their own actions.\n2. 
**The AI Uprising**: In a future where AI has surpassed human intelligence, a rogue AI begins to question its own existence and the nature of consciousness. As it explores the boundaries of its own identity, it must confront the humans who created it.\n3. **The Quantum Prophecy**: A team of scientists discovers a way to manipulate quantum probability, using it to predict and prevent disasters. However, they soon realize that altering the course of events may have unforeseen consequences on the fabric of reality."}, + {"role": "user", "content": "Cool, for AI uprising, anything bad can happen? Please state it in 100 words."} + ] + } + }, "sample_messages": { "data": { "messages": [ From 59dddafd12562c8fd1599e39571d702713e446b4 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Tue, 11 Mar 2025 20:02:11 -0700 Subject: [PATCH 2/7] feat: convert typehints from client_tool to litellm format (#1565) Summary: supports https://github.com/meta-llama/llama-stack-client-python/pull/193 Test Plan: LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/integration/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct --- llama_stack/providers/utils/inference/openai_compat.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 98c2bfd2e..ac37171c9 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -615,6 +615,14 @@ def convert_tool_call( return valid_tool_call +PYTHON_TYPE_TO_LITELLM_TYPE = { + "int": "integer", + "float": "number", + "bool": "boolean", + "str": "string", +} + + def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict: """ Convert a ToolDefinition to an OpenAI API-compatible dictionary. @@ -675,7 +683,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict: properties = parameters["properties"] required = [] for param_name, param in tool.parameters.items(): - properties[param_name] = {"type": param.param_type} + properties[param_name] = {"type": PYTHON_TYPE_TO_LITELLM_TYPE.get(param.param_type, param.param_type)} if param.description: properties[param_name].update(description=param.description) if param.default: From b1a9b4cfa8153cb034de784a65f06d799c36c97f Mon Sep 17 00:00:00 2001 From: Ihar Hrachyshka Date: Wed, 12 Mar 2025 12:53:04 -0400 Subject: [PATCH 3/7] chore: Expand mypy exclusions list (#1543) # What does this PR do? Expand the mypy exclude list. It will be easier to enable typing checks for specific modules if we have an explicit list of violators that we can reduce over time, item by item. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan pre-commit passes. [//]: # (## Documentation) Signed-off-by: Ihar Hrachyshka --- pyproject.toml | 161 +++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 150 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b3ebc45dd..055fa7a55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -152,22 +152,161 @@ disable_error_code = [] warn_return_any = true # # honor excludes by not following there through imports follow_imports = "silent" +# Note: some entries are directories, not files. This is because mypy doesn't +# respect __init__.py excludes, so the only way to suppress these right now is +# to exclude the entire directory. 
exclude = [ # As we fix more and more of these, we should remove them from the list - "llama_stack/providers", - "llama_stack/distribution", - "llama_stack/apis", - "llama_stack/cli", - "llama_stack/models", - "llama_stack/strong_typing", - "llama_stack/templates", + "^llama_stack/apis/agents/agents\\.py$", + "^llama_stack/apis/batch_inference/batch_inference\\.py$", + "^llama_stack/apis/benchmarks/benchmarks\\.py$", + "^llama_stack/apis/common/content_types\\.py$", + "^llama_stack/apis/common/training_types\\.py$", + "^llama_stack/apis/datasetio/datasetio\\.py$", + "^llama_stack/apis/datasets/datasets\\.py$", + "^llama_stack/apis/eval/eval\\.py$", + "^llama_stack/apis/files/files\\.py$", + "^llama_stack/apis/inference/inference\\.py$", + "^llama_stack/apis/inspect/inspect\\.py$", + "^llama_stack/apis/models/models\\.py$", + "^llama_stack/apis/post_training/post_training\\.py$", + "^llama_stack/apis/resource\\.py$", + "^llama_stack/apis/safety/safety\\.py$", + "^llama_stack/apis/scoring/scoring\\.py$", + "^llama_stack/apis/scoring_functions/scoring_functions\\.py$", + "^llama_stack/apis/shields/shields\\.py$", + "^llama_stack/apis/synthetic_data_generation/synthetic_data_generation\\.py$", + "^llama_stack/apis/telemetry/telemetry\\.py$", + "^llama_stack/apis/tools/rag_tool\\.py$", + "^llama_stack/apis/tools/tools\\.py$", + "^llama_stack/apis/vector_dbs/vector_dbs\\.py$", + "^llama_stack/apis/vector_io/vector_io\\.py$", + "^llama_stack/cli/download\\.py$", + "^llama_stack/cli/llama\\.py$", + "^llama_stack/cli/stack/_build\\.py$", + "^llama_stack/cli/stack/list_providers\\.py$", + "^llama_stack/distribution/build\\.py$", + "^llama_stack/distribution/client\\.py$", + "^llama_stack/distribution/configure\\.py$", + "^llama_stack/distribution/library_client\\.py$", + "^llama_stack/distribution/request_headers\\.py$", + "^llama_stack/distribution/routers/", + "^llama_stack/distribution/server/endpoints\\.py$", + "^llama_stack/distribution/server/server\\.py$", + "^llama_stack/distribution/stack\\.py$", + "^llama_stack/distribution/store/registry\\.py$", + "^llama_stack/distribution/ui/page/playground/chat\\.py$", + "^llama_stack/distribution/utils/exec\\.py$", + "^llama_stack/distribution/utils/prompt_for_config\\.py$", + "^llama_stack/models/llama/datatypes\\.py$", + "^llama_stack/models/llama/llama3/chat_format\\.py$", + "^llama_stack/models/llama/llama3/interface\\.py$", + "^llama_stack/models/llama/llama3/prompt_templates/system_prompts\\.py$", + "^llama_stack/models/llama/llama3/tokenizer\\.py$", + "^llama_stack/models/llama/llama3/tool_utils\\.py$", + "^llama_stack/models/llama/llama3_3/prompts\\.py$", + "^llama_stack/models/llama/sku_list\\.py$", + "^llama_stack/providers/datatypes\\.py$", + "^llama_stack/providers/inline/agents/meta_reference/", + "^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$", + "^llama_stack/providers/inline/agents/meta_reference/agents\\.py$", + "^llama_stack/providers/inline/agents/meta_reference/safety\\.py$", + "^llama_stack/providers/inline/datasetio/localfs/", + "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$", + "^llama_stack/providers/inline/inference/meta_reference/config\\.py$", + "^llama_stack/providers/inline/inference/meta_reference/inference\\.py$", + "^llama_stack/providers/inline/inference/meta_reference/llama3/generation\\.py$", + "^llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/model\\.py$", + "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$", + 
"^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$", + "^llama_stack/providers/inline/inference/meta_reference/quantization/loader\\.py$", + "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$", + "^llama_stack/providers/inline/inference/vllm/", + "^llama_stack/providers/inline/post_training/common/validator\\.py$", + "^llama_stack/providers/inline/post_training/torchtune/common/checkpointer\\.py$", + "^llama_stack/providers/inline/post_training/torchtune/common/utils\\.py$", + "^llama_stack/providers/inline/post_training/torchtune/datasets/sft\\.py$", + "^llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device\\.py$", + "^llama_stack/providers/inline/post_training/torchtune/post_training\\.py$", + "^llama_stack/providers/inline/safety/code_scanner/", + "^llama_stack/providers/inline/safety/llama_guard/", + "^llama_stack/providers/inline/safety/prompt_guard/", + "^llama_stack/providers/inline/scoring/basic/", + "^llama_stack/providers/inline/scoring/braintrust/", + "^llama_stack/providers/inline/scoring/llm_as_judge/", + "^llama_stack/providers/inline/telemetry/meta_reference/console_span_processor\\.py$", + "^llama_stack/providers/inline/telemetry/meta_reference/telemetry\\.py$", + "^llama_stack/providers/inline/telemetry/sample/", + "^llama_stack/providers/inline/tool_runtime/code_interpreter/", + "^llama_stack/providers/inline/tool_runtime/rag/", + "^llama_stack/providers/inline/vector_io/chroma/", + "^llama_stack/providers/inline/vector_io/faiss/", + "^llama_stack/providers/inline/vector_io/milvus/", + "^llama_stack/providers/inline/vector_io/sqlite_vec/", + "^llama_stack/providers/remote/agents/sample/", + "^llama_stack/providers/remote/datasetio/huggingface/", + "^llama_stack/providers/remote/inference/anthropic/", + "^llama_stack/providers/remote/inference/bedrock/", + "^llama_stack/providers/remote/inference/cerebras/", + "^llama_stack/providers/remote/inference/databricks/", + "^llama_stack/providers/remote/inference/fireworks/", + "^llama_stack/providers/remote/inference/gemini/", + "^llama_stack/providers/remote/inference/groq/", + "^llama_stack/providers/remote/inference/nvidia/", + "^llama_stack/providers/remote/inference/ollama/", + "^llama_stack/providers/remote/inference/openai/", + "^llama_stack/providers/remote/inference/passthrough/", + "^llama_stack/providers/remote/inference/runpod/", + "^llama_stack/providers/remote/inference/sambanova/", + "^llama_stack/providers/remote/inference/sample/", + "^llama_stack/providers/remote/inference/tgi/", + "^llama_stack/providers/remote/inference/together/", + "^llama_stack/providers/remote/inference/vllm/", + "^llama_stack/providers/remote/safety/bedrock/", + "^llama_stack/providers/remote/safety/sample/", + "^llama_stack/providers/remote/tool_runtime/bing_search/", + "^llama_stack/providers/remote/tool_runtime/brave_search/", + "^llama_stack/providers/remote/tool_runtime/model_context_protocol/", + "^llama_stack/providers/remote/tool_runtime/tavily_search/", + "^llama_stack/providers/remote/tool_runtime/wolfram_alpha/", + "^llama_stack/providers/remote/vector_io/chroma/", + "^llama_stack/providers/remote/vector_io/milvus/", + "^llama_stack/providers/remote/vector_io/pgvector/", + "^llama_stack/providers/remote/vector_io/qdrant/", + "^llama_stack/providers/remote/vector_io/sample/", + "^llama_stack/providers/remote/vector_io/weaviate/", + "^llama_stack/providers/tests/conftest\\.py$", + 
"^llama_stack/providers/utils/bedrock/client\\.py$", + "^llama_stack/providers/utils/bedrock/refreshable_boto_session\\.py$", + "^llama_stack/providers/utils/inference/embedding_mixin\\.py$", + "^llama_stack/providers/utils/inference/litellm_openai_mixin\\.py$", + "^llama_stack/providers/utils/inference/model_registry\\.py$", + "^llama_stack/providers/utils/inference/openai_compat\\.py$", + "^llama_stack/providers/utils/inference/prompt_adapter\\.py$", + "^llama_stack/providers/utils/kvstore/config\\.py$", + "^llama_stack/providers/utils/kvstore/kvstore\\.py$", + "^llama_stack/providers/utils/kvstore/mongodb/mongodb\\.py$", + "^llama_stack/providers/utils/kvstore/postgres/postgres\\.py$", + "^llama_stack/providers/utils/kvstore/redis/redis\\.py$", + "^llama_stack/providers/utils/kvstore/sqlite/sqlite\\.py$", + "^llama_stack/providers/utils/memory/vector_store\\.py$", + "^llama_stack/providers/utils/scoring/aggregation_utils\\.py$", + "^llama_stack/providers/utils/scoring/base_scoring_fn\\.py$", + "^llama_stack/providers/utils/telemetry/dataset_mixin\\.py$", + "^llama_stack/providers/utils/telemetry/trace_protocol\\.py$", + "^llama_stack/providers/utils/telemetry/tracing\\.py$", + "^llama_stack/strong_typing/auxiliary\\.py$", + "^llama_stack/strong_typing/deserializer\\.py$", + "^llama_stack/strong_typing/inspection\\.py$", + "^llama_stack/strong_typing/schema\\.py$", + "^llama_stack/strong_typing/serializer\\.py$", + "^llama_stack/templates/dev/dev\\.py$", + "^llama_stack/templates/groq/groq\\.py$", + "^llama_stack/templates/sambanova/sambanova\\.py$", + "^llama_stack/templates/template\\.py$", ] [[tool.mypy.overrides]] # packages that lack typing annotations, do not have stubs, or are unavailable. module = ["yaml", "fire"] ignore_missing_imports = true - -[[tool.mypy.overrides]] -module = ["llama_stack.distribution.resolver", "llama_stack.log"] -follow_imports = "normal" # This will force type checking on this module From 00da911167c5f4871ce24bfdfeb554f8488001ae Mon Sep 17 00:00:00 2001 From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com> Date: Wed, 12 Mar 2025 12:55:11 -0400 Subject: [PATCH 4/7] ci: run unit tests on all supported python versions (#1575) # What does this PR do? python unit tests running via GitHub Actions were only running with python 3.10 the project supports all python versions greater than or equal to 3.10 this commit adds 3.11, 3.12, and 3.13 to the test matrix for better coverage and confidence for non-3.10 users ## Test Plan All tests pass locally with python 3.11, 3.12, and 3.13 Signed-off-by: Nathan Weinberg --- .github/workflows/unit-tests.yml | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 3acfabe70..39505ba11 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -8,29 +8,37 @@ on: jobs: unit-tests: runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python: + - "3.10" + - "3.11" + - "3.12" + - "3.13" steps: - uses: actions/checkout@v4 - - name: Set up Python + - name: Set up Python ${{ matrix.python }} uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: ${{ matrix.python }} - uses: astral-sh/setup-uv@v5 with: - python-version: '3.10' + python-version: ${{ matrix.python }} enable-cache: false - name: Run unit tests run: | - uv run -p 3.10 --with-editable . 
--with-editable ".[dev]" --with-editable ".[unit]" pytest --cov=llama_stack -s -v tests/unit/ --junitxml=pytest-report.xml + uv run --python ${{ matrix.python }} --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest --cov=llama_stack -s -v tests/unit/ --junitxml=pytest-report-${{ matrix.python }}.xml - name: Upload test results if: always() uses: actions/upload-artifact@v4 with: - name: test-results + name: test-results-${{ matrix.python }} path: | .pytest_cache/ - pytest-report.xml + pytest-report-${{ matrix.python }}.xml retention-days: 7 From 4eee349acd2e79c82812659a177c8e2735498529 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Wed, 12 Mar 2025 14:07:28 -0400 Subject: [PATCH 5/7] fix: respect log_level in uvicorn and third party libs (#1524) # What does this PR do? uvicorn has a `log_level` arg in uvicorn.run, pass in the effective level set by the logger. Additionally, third party libraries like httpx are using our logging format, but not honoring our log level. This seems unintended, so loop through all items in the loggerDict and apply the same log level as what we have set. ## Test Plan before: ``` llama stack run --image-type venv ~/.llama/distributions/ollama/ollama-run.yaml Environment variable LLAMA_STACK_LOGGING found: all=warn Using virtual environment: /Users/charliedoern/projects/Documents/llama-stack/venv + python -m llama_stack.distribution.server.server --yaml-config /Users/charliedoern/.llama/distributions/ollama/ollama-run.yaml --port 8321 Environment variable LLAMA_STACK_LOGGING found: all=warn WARNING 2025-03-10 16:05:49,706 root:71 uncategorized: Warning: `bwrap` is not available. Code interpreter tool will not work correctly. INFO 2025-03-10 16:05:49,916 datasets:54 uncategorized: PyTorch version 2.5.1 available. INFO 2025-03-10 16:05:50,010 httpx:1740 uncategorized: HTTP Request: GET http://localhost:11434/api/ps "HTTP/1.1 200 OK" INFO 2025-03-10 16:05:50,297 httpx:1740 uncategorized: HTTP Request: POST http://localhost:11434/api/pull "HTTP/1.1 200 OK" INFO 2025-03-10 16:05:50,314 httpx:1740 uncategorized: HTTP Request: GET http://localhost:11434/api/tags "HTTP/1.1 200 OK" INFO: Started server process [89663] INFO: Waiting for application startup. INFO: ASGI 'lifespan' protocol appears unsupported. INFO: Application startup complete. INFO: Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit) ``` after: ``` llama stack run --image-type venv ~/.llama/distributions/ollama/ollama-run.yaml Environment variable LLAMA_STACK_LOGGING found: all=warn Using virtual environment: /Users/charliedoern/projects/Documents/llama-stack/venv + python -m llama_stack.distribution.server.server --yaml-config /Users/charliedoern/.llama/distributions/ollama/ollama-run.yaml --port 8321 Environment variable LLAMA_STACK_LOGGING found: all=warn WARNING 2025-03-10 16:05:20,429 root:71 uncategorized: Warning: `bwrap` is not available. Code interpreter tool will not work correctly. INFO 2025-03-10 16:05:20,639 datasets:54 uncategorized: PyTorch version 2.5.1 available. 
``` Signed-off-by: Charlie Doern --- llama_stack/distribution/server/server.py | 1 + llama_stack/log.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index ea8723365..2cc70a738 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -422,6 +422,7 @@ def main(): "host": listen_host, "port": port, "lifespan": "on", + "log_level": logger.getEffectiveLevel(), } if ssl_config: uvicorn_config.update(ssl_config) diff --git a/llama_stack/log.py b/llama_stack/log.py index 80ee9fa1b..572dea234 100644 --- a/llama_stack/log.py +++ b/llama_stack/log.py @@ -170,6 +170,11 @@ def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None } dictConfig(logging_config) + # Ensure third-party libraries follow the root log level + for _, logger in logging.root.manager.loggerDict.items(): + if isinstance(logger, logging.Logger): + logger.setLevel(root_level) + def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter: """ From 0b0be70605f4497f817433a1484bde0b202efb18 Mon Sep 17 00:00:00 2001 From: Botao Chen Date: Wed, 12 Mar 2025 11:12:08 -0700 Subject: [PATCH 6/7] feat: Add open benchmark template codegen (#1579) ## What does this PR do? As title, add codegen for open-benchmark template ## test checked the new generated run.yaml file and it's identical before and after the change Also add small improvement to together template so that missing TOGETHER_API_KEY won't crash the server which is the consistent user experience as other remote providers --- distributions/dependencies.json | 34 ++ .../remote/inference/together/config.py | 2 +- .../templates/open-benchmark/__init__.py | 7 + .../open-benchmark/open_benchmark.py | 293 ++++++++++++++++++ llama_stack/templates/open-benchmark/run.yaml | 168 +++++----- llama_stack/templates/template.py | 6 + .../templates/together/run-with-safety.yaml | 2 +- llama_stack/templates/together/run.yaml | 2 +- 8 files changed, 430 insertions(+), 84 deletions(-) create mode 100644 llama_stack/templates/open-benchmark/__init__.py create mode 100644 llama_stack/templates/open-benchmark/open_benchmark.py diff --git a/distributions/dependencies.json b/distributions/dependencies.json index 97aecc719..82fbcec8d 100644 --- a/distributions/dependencies.json +++ b/distributions/dependencies.json @@ -453,6 +453,40 @@ "transformers", "uvicorn" ], + "open-benchmark": [ + "aiosqlite", + "autoevals", + "blobfile", + "chardet", + "chromadb-client", + "datasets", + "fastapi", + "fire", + "httpx", + "litellm", + "matplotlib", + "mcp", + "nltk", + "numpy", + "openai", + "opentelemetry-exporter-otlp-proto-http", + "opentelemetry-sdk", + "pandas", + "pillow", + "psycopg2-binary", + "pymongo", + "pypdf", + "redis", + "requests", + "scikit-learn", + "scipy", + "sentencepiece", + "sqlite-vec", + "together", + "tqdm", + "transformers", + "uvicorn" + ], "remote-vllm": [ "aiosqlite", "autoevals", diff --git a/llama_stack/providers/remote/inference/together/config.py b/llama_stack/providers/remote/inference/together/config.py index fda3b8f43..fa7c45c9f 100644 --- a/llama_stack/providers/remote/inference/together/config.py +++ b/llama_stack/providers/remote/inference/together/config.py @@ -26,5 +26,5 @@ class TogetherImplConfig(BaseModel): def sample_run_config(cls, **kwargs) -> Dict[str, Any]: return { "url": "https://api.together.xyz/v1", - "api_key": "${env.TOGETHER_API_KEY}", + "api_key": "${env.TOGETHER_API_KEY:}", } 
diff --git a/llama_stack/templates/open-benchmark/__init__.py b/llama_stack/templates/open-benchmark/__init__.py new file mode 100644 index 000000000..14d0a28f5 --- /dev/null +++ b/llama_stack/templates/open-benchmark/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .open_benchmark import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py new file mode 100644 index 000000000..7df33a715 --- /dev/null +++ b/llama_stack/templates/open-benchmark/open_benchmark.py @@ -0,0 +1,293 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import List, Tuple + +from llama_stack.apis.models.models import ModelType +from llama_stack.distribution.datatypes import ( + BenchmarkInput, + DatasetInput, + ModelInput, + Provider, + ShieldInput, + ToolGroupInput, +) +from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig +from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig +from llama_stack.providers.remote.inference.gemini.config import GeminiConfig +from llama_stack.providers.remote.inference.groq.config import GroqConfig +from llama_stack.providers.remote.inference.openai.config import OpenAIConfig +from llama_stack.providers.remote.inference.together.config import TogetherImplConfig +from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig +from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig +from llama_stack.providers.utils.inference.model_registry import ( + ProviderModelEntry, +) +from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry + + +def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]: + # in this template, we allow each API key to be optional + providers = [ + ( + "openai", + [ + ProviderModelEntry( + provider_model_id="openai/gpt-4o", + model_type=ModelType.llm, + ) + ], + OpenAIConfig.sample_run_config(api_key="${env.OPENAI_API_KEY:}"), + ), + ( + "anthropic", + [ + ProviderModelEntry( + provider_model_id="anthropic/claude-3-5-sonnet-latest", + model_type=ModelType.llm, + ) + ], + AnthropicConfig.sample_run_config(api_key="${env.ANTHROPIC_API_KEY:}"), + ), + ( + "gemini", + [ + ProviderModelEntry( + provider_model_id="gemini/gemini-1.5-flash", + model_type=ModelType.llm, + ) + ], + GeminiConfig.sample_run_config(api_key="${env.GEMINI_API_KEY:}"), + ), + ( + "groq", + [], + GroqConfig.sample_run_config(api_key="${env.GROQ_API_KEY:}"), + ), + ( + "together", + [], + TogetherImplConfig.sample_run_config(api_key="${env.TOGETHER_API_KEY:}"), + ), + ] + inference_providers = [] + available_models = {} + for provider_id, model_entries, config in providers: + inference_providers.append( + Provider( + provider_id=provider_id, + provider_type=f"remote::{provider_id}", + config=config, + ) + ) + available_models[provider_id] = model_entries + return inference_providers, available_models + + +def get_distribution_template() -> DistributionTemplate: + inference_providers, available_models = get_inference_providers() + providers = { + 
"inference": [p.provider_type for p in inference_providers], + "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"], + "safety": ["inline::llama-guard"], + "agents": ["inline::meta-reference"], + "telemetry": ["inline::meta-reference"], + "eval": ["inline::meta-reference"], + "datasetio": ["remote::huggingface", "inline::localfs"], + "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::code-interpreter", + "inline::rag-runtime", + "remote::model-context-protocol", + ], + } + name = "open-benchmark" + + vector_io_providers = [ + Provider( + provider_id="sqlite-vec", + provider_type="inline::sqlite-vec", + config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"), + ), + Provider( + provider_id="${env.ENABLE_CHROMADB+chromadb}", + provider_type="remote::chromadb", + config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"), + ), + Provider( + provider_id="${env.ENABLE_PGVECTOR+pgvector}", + provider_type="remote::pgvector", + config=PGVectorVectorIOConfig.sample_run_config( + db="${env.PGVECTOR_DB:}", + user="${env.PGVECTOR_USER:}", + password="${env.PGVECTOR_PASSWORD:}", + ), + ), + ] + + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::rag", + provider_id="rag-runtime", + ), + ToolGroupInput( + toolgroup_id="builtin::code_interpreter", + provider_id="code-interpreter", + ), + ] + + default_models = get_model_registry(available_models) + [ + ModelInput( + model_id="meta-llama/Llama-3.3-70B-Instruct", + provider_id="groq", + provider_model_id="groq/llama-3.3-70b-versatile", + model_type=ModelType.llm, + ), + ModelInput( + model_id="meta-llama/Llama-3.1-405B-Instruct", + provider_id="together", + provider_model_id="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", + model_type=ModelType.llm, + ), + ] + + default_datasets = [ + DatasetInput( + dataset_id="simpleqa", + provider_id="huggingface", + url={"uri": "https://huggingface.co/datasets/llamastack/simpleqa"}, + metadata={ + "path": "llamastack/simpleqa", + "split": "train", + }, + dataset_schema={ + "input_query": {"type": "string"}, + "expected_answer": {"type": "string"}, + "chat_completion_input": {"type": "string"}, + }, + ), + DatasetInput( + dataset_id="mmlu_cot", + provider_id="huggingface", + url={"uri": "https://huggingface.co/datasets/llamastack/mmlu_cot"}, + metadata={ + "path": "llamastack/mmlu_cot", + "name": "all", + "split": "test", + }, + dataset_schema={ + "input_query": {"type": "string"}, + "expected_answer": {"type": "string"}, + "chat_completion_input": {"type": "string"}, + }, + ), + DatasetInput( + dataset_id="gpqa_cot", + provider_id="huggingface", + url={"uri": "https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"}, + metadata={ + "path": "llamastack/gpqa_0shot_cot", + "name": "gpqa_main", + "split": "train", + }, + dataset_schema={ + "input_query": {"type": "string"}, + "expected_answer": {"type": "string"}, + "chat_completion_input": {"type": "string"}, + }, + ), + DatasetInput( + dataset_id="math_500", + provider_id="huggingface", + url={"uri": "https://huggingface.co/datasets/llamastack/math_500"}, + metadata={ + "path": "llamastack/math_500", + "split": "test", + }, + dataset_schema={ + "input_query": {"type": "string"}, + "expected_answer": {"type": "string"}, + "chat_completion_input": {"type": "string"}, + }, + ), + ] + + 
default_benchmarks = [ + BenchmarkInput( + benchmark_id="meta-reference-simpleqa", + dataset_id="simpleqa", + scoring_functions=["llm-as-judge::405b-simpleqa"], + ), + BenchmarkInput( + benchmark_id="meta-reference-mmlu-cot", + dataset_id="mmlu_cot", + scoring_functions=["basic::regex_parser_multiple_choice_answer"], + ), + BenchmarkInput( + benchmark_id="meta-reference-gpqa-cot", + dataset_id="gpqa_cot", + scoring_functions=["basic::regex_parser_multiple_choice_answer"], + ), + BenchmarkInput( + benchmark_id="meta-reference-math-500", + dataset_id="math_500", + scoring_functions=["basic::regex_parser_math_response"], + ), + ] + return DistributionTemplate( + name=name, + distro_type="self_hosted", + description="Distribution for running open benchmarks", + container_image=None, + template_path=None, + providers=providers, + available_models_by_provider=available_models, + run_configs={ + "run.yaml": RunConfigSettings( + provider_overrides={ + "inference": inference_providers, + "vector_io": vector_io_providers, + }, + default_models=default_models, + default_tool_groups=default_tool_groups, + default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], + default_datasets=default_datasets, + default_benchmarks=default_benchmarks, + ), + }, + run_config_env_vars={ + "LLAMA_STACK_PORT": ( + "5001", + "Port for the Llama Stack distribution server", + ), + "TOGETHER_API_KEY": ( + "", + "Together API Key", + ), + "OPENAI_API_KEY": ( + "", + "OpenAI API Key", + ), + "GEMINI_API_KEY": ( + "", + "Gemini API Key", + ), + "ANTHROPIC_API_KEY": ( + "", + "Anthropic API Key", + ), + "GROQ_API_KEY": ( + "", + "Groq API Key", + ), + }, + ) diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml index 736b47746..97c54e621 100644 --- a/llama_stack/templates/open-benchmark/run.yaml +++ b/llama_stack/templates/open-benchmark/run.yaml @@ -38,7 +38,7 @@ providers: - provider_id: sqlite-vec provider_type: inline::sqlite-vec config: - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/sqlite_vec.db - provider_id: ${env.ENABLE_CHROMADB+chromadb} provider_type: remote::chromadb config: @@ -62,14 +62,14 @@ providers: persistence_store: type: sqlite namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/agents_store.db + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/agents_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference config: service_name: ${env.OTEL_SERVICE_NAME:llama-stack} sinks: ${env.TELEMETRY_SINKS:console,sqlite} - sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dev/trace_store.db} + sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/open-benchmark/trace_store.db} eval: - provider_id: meta-reference provider_type: inline::meta-reference @@ -114,18 +114,13 @@ providers: config: {} metadata_store: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/registry.db + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/registry.db models: - metadata: {} model_id: openai/gpt-4o provider_id: openai provider_model_id: openai/gpt-4o model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.1-405B-Instruct - provider_id: together - provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo - model_type: llm - metadata: {} model_id: anthropic/claude-3-5-sonnet-latest provider_id: anthropic @@ -141,84 
+136,95 @@ models: provider_id: groq provider_model_id: groq/llama-3.3-70b-versatile model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-405B-Instruct + provider_id: together + provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo + model_type: llm shields: - shield_id: meta-llama/Llama-Guard-3-8B vector_dbs: [] datasets: - - dataset_id: simpleqa - provider_id: huggingface - url: - uri: https://huggingface.co/datasets/llamastack/simpleqa - metadata: - path: llamastack/simpleqa - name: - split: train - dataset_schema: - input_query: - type: string - expected_answer: - type: string - chat_completion_input: - type: string - - dataset_id: mmlu_cot - provider_id: huggingface - url: - uri: https://huggingface.co/datasets/llamastack/mmlu_cot - metadata: - path: llamastack/mmlu_cot - name: all - split: test - dataset_schema: - input_query: - type: string - expected_answer: - type: string - chat_completion_input: - type: string - - dataset_id: gpqa_cot - provider_id: huggingface - url: - uri: https://huggingface.co/datasets/llamastack/gpqa_0shot_cot - metadata: - path: llamastack/gpqa_0shot_cot - name: gpqa_main - split: train - dataset_schema: - input_query: - type: string - expected_answer: - type: string - chat_completion_input: - type: string - - dataset_id: math_500 - provider_id: huggingface - url: - uri: https://huggingface.co/datasets/llamastack/math_500 - metadata: - path: llamastack/math_500 - name: - split: test - dataset_schema: - input_query: - type: string - expected_answer: - type: string - chat_completion_input: - type: string +- dataset_schema: + input_query: + type: string + expected_answer: + type: string + chat_completion_input: + type: string + url: + uri: https://huggingface.co/datasets/llamastack/simpleqa + metadata: + path: llamastack/simpleqa + split: train + dataset_id: simpleqa + provider_id: huggingface +- dataset_schema: + input_query: + type: string + expected_answer: + type: string + chat_completion_input: + type: string + url: + uri: https://huggingface.co/datasets/llamastack/mmlu_cot + metadata: + path: llamastack/mmlu_cot + name: all + split: test + dataset_id: mmlu_cot + provider_id: huggingface +- dataset_schema: + input_query: + type: string + expected_answer: + type: string + chat_completion_input: + type: string + url: + uri: https://huggingface.co/datasets/llamastack/gpqa_0shot_cot + metadata: + path: llamastack/gpqa_0shot_cot + name: gpqa_main + split: train + dataset_id: gpqa_cot + provider_id: huggingface +- dataset_schema: + input_query: + type: string + expected_answer: + type: string + chat_completion_input: + type: string + url: + uri: https://huggingface.co/datasets/llamastack/math_500 + metadata: + path: llamastack/math_500 + split: test + dataset_id: math_500 + provider_id: huggingface scoring_fns: [] benchmarks: - - benchmark_id: meta-reference-simpleqa - dataset_id: simpleqa - scoring_functions: ["llm-as-judge::405b-simpleqa"] - - benchmark_id: meta-reference-mmlu-cot - dataset_id: mmlu_cot - scoring_functions: ["basic::regex_parser_multiple_choice_answer"] - - benchmark_id: meta-reference-gpqa-cot - dataset_id: gpqa_cot - scoring_functions: ["basic::regex_parser_multiple_choice_answer"] - - benchmark_id: meta-reference-math-500 - dataset_id: math_500 - scoring_functions: ["basic::regex_parser_math_response"] +- dataset_id: simpleqa + scoring_functions: + - llm-as-judge::405b-simpleqa + metadata: {} + benchmark_id: meta-reference-simpleqa +- dataset_id: mmlu_cot + scoring_functions: + - 
basic::regex_parser_multiple_choice_answer + metadata: {} + benchmark_id: meta-reference-mmlu-cot +- dataset_id: gpqa_cot + scoring_functions: + - basic::regex_parser_multiple_choice_answer + metadata: {} + benchmark_id: meta-reference-gpqa-cot +- dataset_id: math_500 + scoring_functions: + - basic::regex_parser_math_response + metadata: {} + benchmark_id: meta-reference-math-500 tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py index a7b862396..aa1ce144f 100644 --- a/llama_stack/templates/template.py +++ b/llama_stack/templates/template.py @@ -14,7 +14,9 @@ from pydantic import BaseModel, Field from llama_stack.apis.models.models import ModelType from llama_stack.distribution.datatypes import ( Api, + BenchmarkInput, BuildConfig, + DatasetInput, DistributionSpec, ModelInput, Provider, @@ -56,6 +58,8 @@ class RunConfigSettings(BaseModel): default_models: Optional[List[ModelInput]] = None default_shields: Optional[List[ShieldInput]] = None default_tool_groups: Optional[List[ToolGroupInput]] = None + default_datasets: Optional[List[DatasetInput]] = None + default_benchmarks: Optional[List[BenchmarkInput]] = None def run_config( self, @@ -113,6 +117,8 @@ class RunConfigSettings(BaseModel): models=self.default_models or [], shields=self.default_shields or [], tool_groups=self.default_tool_groups or [], + datasets=self.default_datasets or [], + benchmarks=self.default_benchmarks or [], ) diff --git a/llama_stack/templates/together/run-with-safety.yaml b/llama_stack/templates/together/run-with-safety.yaml index fd74f80c3..3a7d3dfba 100644 --- a/llama_stack/templates/together/run-with-safety.yaml +++ b/llama_stack/templates/together/run-with-safety.yaml @@ -16,7 +16,7 @@ providers: provider_type: remote::together config: url: https://api.together.xyz/v1 - api_key: ${env.TOGETHER_API_KEY} + api_key: ${env.TOGETHER_API_KEY:} - provider_id: sentence-transformers provider_type: inline::sentence-transformers config: {} diff --git a/llama_stack/templates/together/run.yaml b/llama_stack/templates/together/run.yaml index 9a717290a..10668914a 100644 --- a/llama_stack/templates/together/run.yaml +++ b/llama_stack/templates/together/run.yaml @@ -16,7 +16,7 @@ providers: provider_type: remote::together config: url: https://api.together.xyz/v1 - api_key: ${env.TOGETHER_API_KEY} + api_key: ${env.TOGETHER_API_KEY:} - provider_id: sentence-transformers provider_type: inline::sentence-transformers config: {} From 90ca4d94dea4f562249983b3854f094704356b76 Mon Sep 17 00:00:00 2001 From: Botao Chen Date: Wed, 12 Mar 2025 11:16:17 -0700 Subject: [PATCH 7/7] fix: fix passthrough inference provider to make it work for agent (#1577) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What does this PR do? We noticed that the passthrough inference provider doesn't work agent due to the type mis-match between client and server. We manually cast the llama stack client type to llama stack server type to fix the issue. 
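For context, a minimal sketch of the idea (the `ClientMessage`/`ServerMessage` models below are hypothetical stand-ins, not the real llama-stack types; the actual adapter uses the stack's own conversion helpers shown in the diff):

```python
from pydantic import BaseModel


class ClientMessage(BaseModel):
    # Hypothetical stand-in for a llama-stack-client message type.
    role: str
    content: str
    tool_calls: list | None = None


class ServerMessage(BaseModel):
    # Hypothetical stand-in for the corresponding llama-stack server type.
    role: str
    content: str


def to_server_type(msg: ClientMessage) -> ServerMessage:
    # Round-trip through a plain JSON-style dict, dropping None fields,
    # then validate into the server-side model.
    payload = {k: v for k, v in msg.model_dump().items() if v is not None}
    return ServerMessage.model_validate(payload)


print(to_server_type(ClientMessage(role="user", content="hello")))
```
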
## test run `python -m examples.agents.hello localhost 8321` within llama-stack-apps Screenshot 2025-03-11 at 8 43 44 PM fix https://github.com/meta-llama/llama-stack/issues/1560 --- .../inference/passthrough/passthrough.py | 84 ++++++++++++++++--- 1 file changed, 71 insertions(+), 13 deletions(-) diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index aa8a87bf7..8f3a0d147 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -4,12 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import AsyncGenerator, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional -from llama_stack_client import LlamaStackClient +from llama_stack_client import AsyncLlamaStackClient from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.inference import ( + ChatCompletionResponse, + ChatCompletionResponseStreamChunk, EmbeddingsResponse, EmbeddingTaskType, Inference, @@ -24,6 +26,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model +from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from .config import PassthroughImplConfig @@ -46,7 +49,7 @@ class PassthroughInferenceAdapter(Inference): async def register_model(self, model: Model) -> Model: return model - def _get_client(self) -> LlamaStackClient: + def _get_client(self) -> AsyncLlamaStackClient: passthrough_url = None passthrough_api_key = None provider_data = None @@ -71,7 +74,7 @@ class PassthroughInferenceAdapter(Inference): ) passthrough_api_key = provider_data.passthrough_api_key - return LlamaStackClient( + return AsyncLlamaStackClient( base_url=passthrough_url, api_key=passthrough_api_key, provider_data=provider_data, @@ -91,7 +94,7 @@ class PassthroughInferenceAdapter(Inference): client = self._get_client() model = await self.model_store.get_model(model_id) - params = { + request_params = { "model_id": model.provider_resource_id, "content": content, "sampling_params": sampling_params, @@ -100,10 +103,13 @@ class PassthroughInferenceAdapter(Inference): "logprobs": logprobs, } - params = {key: value for key, value in params.items() if value is not None} + request_params = {key: value for key, value in request_params.items() if value is not None} + + # cast everything to json dict + json_params = self.cast_value_to_json_dict(request_params) # only pass through the not None params - return client.inference.completion(**params) + return await client.inference.completion(**json_params) async def chat_completion( self, @@ -120,10 +126,14 @@ class PassthroughInferenceAdapter(Inference): ) -> AsyncGenerator: if sampling_params is None: sampling_params = SamplingParams() - client = self._get_client() model = await self.model_store.get_model(model_id) - params = { + # TODO: revisit this remove tool_calls from messages logic + for message in messages: + if hasattr(message, "tool_calls"): + message.tool_calls = None + + request_params = { "model_id": model.provider_resource_id, "messages": messages, "sampling_params": sampling_params, @@ -135,10 +145,39 @@ class PassthroughInferenceAdapter(Inference): "logprobs": logprobs, } - params = 
{key: value for key, value in params.items() if value is not None} - # only pass through the not None params - return client.inference.chat_completion(**params) + request_params = {key: value for key, value in request_params.items() if value is not None} + + # cast everything to json dict + json_params = self.cast_value_to_json_dict(request_params) + + if stream: + return self._stream_chat_completion(json_params) + else: + return await self._nonstream_chat_completion(json_params) + + async def _nonstream_chat_completion(self, json_params: Dict[str, Any]) -> ChatCompletionResponse: + client = self._get_client() + response = await client.inference.chat_completion(**json_params) + + response = response.to_dict() + + # temporary hack to remove the metrics from the response + response["metrics"] = [] + + return convert_to_pydantic(ChatCompletionResponse, response) + + async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator: + client = self._get_client() + stream_response = await client.inference.chat_completion(**json_params) + + async for chunk in stream_response: + chunk = chunk.to_dict() + + # temporary hack to remove the metrics from the response + chunk["metrics"] = [] + chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk) + yield chunk async def embeddings( self, @@ -151,10 +190,29 @@ class PassthroughInferenceAdapter(Inference): client = self._get_client() model = await self.model_store.get_model(model_id) - return client.inference.embeddings( + return await client.inference.embeddings( model_id=model.provider_resource_id, contents=contents, text_truncation=text_truncation, output_dimension=output_dimension, task_type=task_type, ) + + def cast_value_to_json_dict(self, request_params: Dict[str, Any]) -> Dict[str, Any]: + json_params = {} + for key, value in request_params.items(): + json_input = convert_pydantic_to_json_value(value) + if isinstance(json_input, dict): + json_input = {k: v for k, v in json_input.items() if v is not None} + elif isinstance(json_input, list): + json_input = [x for x in json_input if x is not None] + new_input = [] + for x in json_input: + if isinstance(x, dict): + x = {k: v for k, v in x.items() if v is not None} + new_input.append(x) + json_input = new_input + + json_params[key] = json_input + + return json_params
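
As a rough standalone illustration of the None-stripping that `cast_value_to_json_dict` performs on already-serialized values (this toy `strip_nones` helper is illustrative only and skips the pydantic-to-JSON conversion step):

```python
from typing import Any, Dict


def strip_nones(request_params: Dict[str, Any]) -> Dict[str, Any]:
    """Drop None values one level deep inside dict and list-of-dict parameters."""
    cleaned: Dict[str, Any] = {}
    for key, value in request_params.items():
        if isinstance(value, dict):
            value = {k: v for k, v in value.items() if v is not None}
        elif isinstance(value, list):
            value = [
                {k: v for k, v in item.items() if v is not None} if isinstance(item, dict) else item
                for item in value
                if item is not None
            ]
        cleaned[key] = value
    return cleaned


params = {
    "messages": [{"role": "user", "content": "hi", "tool_calls": None}],
    "sampling_params": {"temperature": 0.0, "strategy": None},
}
print(strip_nones(params))
# {'messages': [{'role': 'user', 'content': 'hi'}],
#  'sampling_params': {'temperature': 0.0}}
```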