From 3c72c034e6ef526aed8c4e4dadb0369bd30f8bb0 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Fri, 27 Dec 2024 15:45:44 -0800 Subject: [PATCH] [remove import *] clean up import *'s (#689) # What does this PR do? - as title, cleaning up `import *`'s - upgrade tests to make them more robust to bad model outputs - remove import *'s in llama_stack/apis/* (skip __init__ modules) image - run `sh run_openapi_generator.sh`, no types gets affected ## Test Plan ### Providers Tests **agents** ``` pytest -v -s llama_stack/providers/tests/agents/test_agents.py -m "together" --safety-shield meta-llama/Llama-Guard-3-8B --inference-model meta-llama/Llama-3.1-405B-Instruct-FP8 ``` **inference** ```bash # meta-reference torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py # together pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py pytest ./llama_stack/providers/tests/inference/test_prompt_adapter.py ``` **safety** ``` pytest -v -s llama_stack/providers/tests/safety/test_safety.py -m together --safety-shield meta-llama/Llama-Guard-3-8B ``` **memory** ``` pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m "sentence_transformers" --env EMBEDDING_DIMENSION=384 ``` **scoring** ``` pytest -v -s -m llm_as_judge_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py --judge-model meta-llama/Llama-3.2-3B-Instruct pytest -v -s -m basic_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py pytest -v -s -m braintrust_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py ``` **datasetio** ``` pytest -v -s -m localfs llama_stack/providers/tests/datasetio/test_datasetio.py pytest -v -s -m huggingface llama_stack/providers/tests/datasetio/test_datasetio.py ``` **eval** ``` pytest -v -s -m meta_reference_eval_together_inference llama_stack/providers/tests/eval/test_eval.py pytest -v -s -m meta_reference_eval_together_inference_huggingface_datasetio llama_stack/providers/tests/eval/test_eval.py ``` ### Client-SDK Tests ``` LLAMA_STACK_BASE_URL=http://localhost:5000 pytest -v ./tests/client-sdk ``` ### llama-stack-apps ``` PORT=5000 LOCALHOST=localhost python -m examples.agents.hello $LOCALHOST $PORT python -m examples.agents.inflation $LOCALHOST $PORT python -m examples.agents.podcast_transcript $LOCALHOST $PORT python -m examples.agents.rag_as_attachments $LOCALHOST $PORT python -m examples.agents.rag_with_memory_bank $LOCALHOST $PORT python -m examples.safety.llama_guard_demo_mm $LOCALHOST $PORT python -m examples.agents.e2e_loop_with_custom_tools $LOCALHOST $PORT # Vision model python -m examples.interior_design_assistant.app python -m examples.agent_store.app $LOCALHOST $PORT ``` ### CLI ``` which llama llama model prompt-format -m Llama3.2-11B-Vision-Instruct llama model list llama stack list-apis llama stack list-providers inference llama stack build --template ollama --image-type conda ``` ### Distributions Tests **ollama** ``` llama stack build --template ollama --image-type conda ollama run llama3.2:1b-instruct-fp16 llama stack run ./llama_stack/templates/ollama/run.yaml --env INFERENCE_MODEL=meta-llama/Llama-3.2-1B-Instruct ``` **fireworks** ``` llama stack build --template fireworks --image-type conda llama stack run ./llama_stack/templates/fireworks/run.yaml ``` **together** ``` llama stack build --template together --image-type conda llama stack run ./llama_stack/templates/together/run.yaml ``` **tgi** ``` llama stack run ./llama_stack/templates/tgi/run.yaml --env TGI_URL=http://0.0.0.0:5009 --env INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct ``` ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- docs/zero_to_hero_guide/06_Safety101.ipynb | 4 +- llama_stack/apis/agents/agents.py | 24 ++++++-- llama_stack/apis/agents/event_logger.py | 5 +- .../apis/batch_inference/batch_inference.py | 12 +++- llama_stack/apis/datasetio/datasetio.py | 2 +- llama_stack/apis/eval/eval.py | 12 ++-- llama_stack/apis/inference/inference.py | 5 +- .../apis/post_training/post_training.py | 8 +-- llama_stack/apis/scoring/scoring.py | 5 +- .../synthetic_data_generation.py | 3 +- llama_stack/cli/model/safety_models.py | 7 ++- llama_stack/cli/stack/build.py | 15 +++-- llama_stack/distribution/build.py | 11 ++-- llama_stack/distribution/configure.py | 15 ++--- llama_stack/distribution/datatypes.py | 16 ++--- llama_stack/distribution/inspect.py | 6 +- llama_stack/distribution/resolver.py | 30 ++++++++-- llama_stack/distribution/routers/__init__.py | 6 +- llama_stack/distribution/routers/routers.py | 43 ++++++++++---- .../distribution/routers/routing_tables.py | 39 +++++++++--- llama_stack/distribution/server/server.py | 17 +++--- llama_stack/distribution/stack.py | 39 ++++++------ llama_stack/distribution/store/registry.py | 7 +-- .../distribution/store/tests/test_registry.py | 7 ++- .../agents/meta_reference/agent_instance.py | 59 ++++++++++++++++--- .../inline/agents/meta_reference/agents.py | 17 +++++- .../agents/meta_reference/persistence.py | 4 +- .../meta_reference/rag/context_retriever.py | 4 +- .../inline/agents/meta_reference/safety.py | 4 +- .../meta_reference/tests/test_chat_agent.py | 24 ++++++-- .../agents/meta_reference/tools/safety.py | 2 +- .../inline/datasetio/localfs/config.py | 2 +- .../inline/datasetio/localfs/datasetio.py | 13 ++-- .../inline/eval/meta_reference/eval.py | 13 ++-- .../inline/inference/meta_reference/config.py | 5 +- .../inference/meta_reference/generation.py | 18 +++--- .../providers/inline/inference/vllm/vllm.py | 25 ++++++-- .../providers/inline/memory/faiss/faiss.py | 11 ++-- .../post_training/torchtune/common/utils.py | 5 +- .../post_training/torchtune/post_training.py | 17 +++++- .../recipes/lora_finetuning_single_device.py | 26 +++++--- .../safety/code_scanner/code_scanner.py | 8 ++- .../inline/safety/llama_guard/llama_guard.py | 20 ++++++- .../safety/prompt_guard/prompt_guard.py | 13 ++-- .../providers/inline/scoring/basic/scoring.py | 17 +++--- .../inline/scoring/braintrust/braintrust.py | 21 ++++--- .../inline/scoring/braintrust/config.py | 4 +- .../telemetry/meta_reference/telemetry.py | 20 +++++-- .../inline/telemetry/sample/sample.py | 4 +- llama_stack/providers/registry/agents.py | 8 ++- llama_stack/providers/registry/datasetio.py | 8 ++- llama_stack/providers/registry/eval.py | 2 +- llama_stack/providers/registry/inference.py | 9 ++- llama_stack/providers/registry/memory.py | 9 ++- .../providers/registry/post_training.py | 2 +- llama_stack/providers/registry/safety.py | 2 +- llama_stack/providers/registry/scoring.py | 2 +- llama_stack/providers/registry/telemetry.py | 8 ++- .../providers/registry/tool_runtime.py | 2 +- .../providers/remote/agents/sample/sample.py | 4 +- .../datasetio/huggingface/huggingface.py | 6 +- .../remote/inference/bedrock/bedrock.py | 25 ++++++-- .../remote/inference/cerebras/cerebras.py | 22 +++++-- .../remote/inference/databricks/databricks.py | 17 +++++- .../remote/inference/fireworks/fireworks.py | 19 +++++- .../remote/inference/ollama/ollama.py | 28 +++++++-- .../remote/inference/sample/sample.py | 5 +- .../providers/remote/inference/tgi/tgi.py | 21 ++++++- .../remote/inference/together/together.py | 19 +++++- .../providers/remote/inference/vllm/vllm.py | 22 ++++++- .../providers/remote/memory/chroma/chroma.py | 10 +++- .../remote/memory/pgvector/pgvector.py | 12 +++- .../providers/remote/memory/qdrant/qdrant.py | 13 ++-- .../providers/remote/memory/sample/sample.py | 5 +- .../remote/memory/weaviate/weaviate.py | 10 +++- .../remote/safety/bedrock/bedrock.py | 11 +++- .../providers/remote/safety/sample/sample.py | 5 +- .../providers/tests/agents/test_agents.py | 24 +++++++- .../tests/agents/test_persistence.py | 6 +- .../tests/datasetio/test_datasetio.py | 13 ++-- llama_stack/providers/tests/eval/test_eval.py | 4 +- .../tests/inference/test_prompt_adapter.py | 20 ++++--- .../tests/inference/test_text_inference.py | 29 +++++++-- .../tests/inference/test_vision_inference.py | 11 +++- .../providers/tests/memory/fixtures.py | 5 +- .../providers/tests/memory/test_memory.py | 12 ++-- .../providers/tests/post_training/fixtures.py | 3 +- .../tests/post_training/test_post_training.py | 15 ++++- llama_stack/providers/tests/resolver.py | 14 ++++- .../providers/tests/safety/test_safety.py | 6 +- .../providers/tests/scoring/test_scoring.py | 2 +- .../utils/inference/openai_compat.py | 19 ++++-- .../providers/utils/kvstore/kvstore.py | 6 +- .../providers/utils/kvstore/redis/redis.py | 2 +- .../providers/utils/kvstore/sqlite/sqlite.py | 2 +- .../providers/utils/memory/vector_store.py | 13 ++-- .../utils/scoring/aggregation_utils.py | 3 +- .../providers/utils/telemetry/tracing.py | 14 ++++- tests/client-sdk/agents/test_agents.py | 43 +++++++++----- 99 files changed, 907 insertions(+), 359 deletions(-) diff --git a/docs/zero_to_hero_guide/06_Safety101.ipynb b/docs/zero_to_hero_guide/06_Safety101.ipynb index 6b5bd53bf..e2ba5e22e 100644 --- a/docs/zero_to_hero_guide/06_Safety101.ipynb +++ b/docs/zero_to_hero_guide/06_Safety101.ipynb @@ -67,7 +67,7 @@ "from termcolor import cprint\n", "\n", "from llama_stack.distribution.datatypes import RemoteProviderConfig\n", - "from llama_stack.apis.safety import * # noqa: F403\n", + "from llama_stack.apis.safety import Safety\n", "from llama_stack_client import LlamaStackClient\n", "\n", "\n", @@ -127,7 +127,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.11.10" } }, "nbformat": 4, diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 5fd90ae7a..5748b4e41 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -18,18 +18,30 @@ from typing import ( Union, ) +from llama_models.llama3.api.datatypes import ToolParamDefinition + from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated -from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.common.deployment_types import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 -from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.common.content_types import InterleavedContent, URL +from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig +from llama_stack.apis.inference import ( + CompletionMessage, + SamplingParams, + ToolCall, + ToolCallDelta, + ToolChoice, + ToolPromptFormat, + ToolResponse, + ToolResponseMessage, + UserMessage, +) +from llama_stack.apis.memory import MemoryBank +from llama_stack.apis.safety import SafetyViolation + +from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @json_schema_type diff --git a/llama_stack/apis/agents/event_logger.py b/llama_stack/apis/agents/event_logger.py index 4c379999e..40a69d19c 100644 --- a/llama_stack/apis/agents/event_logger.py +++ b/llama_stack/apis/agents/event_logger.py @@ -6,13 +6,14 @@ from typing import Optional -from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_models.llama3.api.datatypes import ToolPromptFormat from llama_models.llama3.api.tool_utils import ToolUtils - from termcolor import cprint from llama_stack.apis.agents import AgentTurnResponseEventType, StepType +from llama_stack.apis.inference import ToolResponseMessage + class LogEvent: def __init__( diff --git a/llama_stack/apis/batch_inference/batch_inference.py b/llama_stack/apis/batch_inference/batch_inference.py index 358cf3c35..f7b8b4387 100644 --- a/llama_stack/apis/batch_inference/batch_inference.py +++ b/llama_stack/apis/batch_inference/batch_inference.py @@ -10,8 +10,16 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.inference import ( + CompletionMessage, + InterleavedContent, + LogProbConfig, + Message, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) @json_schema_type diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py index 22acc3211..983e0e4ea 100644 --- a/llama_stack/apis/datasetio/datasetio.py +++ b/llama_stack/apis/datasetio/datasetio.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel -from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.datasets import Dataset @json_schema_type diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index 2e0ce1fbc..2592bca37 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -4,18 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Literal, Optional, Protocol, Union +from typing import Any, Dict, List, Literal, Optional, Protocol, Union + +from llama_models.llama3.api.datatypes import BaseModel, Field +from llama_models.schema_utils import json_schema_type, webmethod from typing_extensions import Annotated -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_models.schema_utils import json_schema_type, webmethod -from llama_stack.apis.scoring_functions import * # noqa: F403 from llama_stack.apis.agents import AgentConfig from llama_stack.apis.common.job_types import Job, JobStatus -from llama_stack.apis.scoring import * # noqa: F403 -from llama_stack.apis.eval_tasks import * # noqa: F403 from llama_stack.apis.inference import SamplingParams, SystemMessage +from llama_stack.apis.scoring import ScoringResult +from llama_stack.apis.scoring_functions import ScoringFnParams @json_schema_type diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 28b9d9106..e48042091 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -7,7 +7,9 @@ from enum import Enum from typing import ( + Any, AsyncIterator, + Dict, List, Literal, Optional, @@ -32,8 +34,9 @@ from typing_extensions import Annotated from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.models import Model + from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol -from llama_stack.apis.models import * # noqa: F403 class LogProbConfig(BaseModel): diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index fdbaa364d..1c2d2d6e2 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -7,17 +7,17 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Protocol, Union +from typing import Any, Dict, List, Literal, Optional, Protocol, Union from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated -from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.common.content_types import URL + from llama_stack.apis.common.job_types import JobStatus -from llama_stack.apis.datasets import * # noqa: F403 -from llama_stack.apis.common.training_types import * # noqa: F403 +from llama_stack.apis.common.training_types import Checkpoint @json_schema_type diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py index a47620a3d..453e35f6d 100644 --- a/llama_stack/apis/scoring/scoring.py +++ b/llama_stack/apis/scoring/scoring.py @@ -4,13 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Protocol, runtime_checkable +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.scoring_functions import * # noqa: F403 +from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams # mapping of metric to value diff --git a/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py index 4ffaa4d1e..13b209912 100644 --- a/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py +++ b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py @@ -6,13 +6,12 @@ from enum import Enum -from typing import Any, Dict, List, Optional, Protocol +from typing import Any, Dict, List, Optional, Protocol, Union from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel -from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.inference import Message diff --git a/llama_stack/cli/model/safety_models.py b/llama_stack/cli/model/safety_models.py index 39c133f73..9464e0a2d 100644 --- a/llama_stack/cli/model/safety_models.py +++ b/llama_stack/cli/model/safety_models.py @@ -6,11 +6,12 @@ from typing import Any, Dict, Optional -from pydantic import BaseModel, ConfigDict, Field - -from llama_models.datatypes import * # noqa: F403 +from llama_models.datatypes import CheckpointQuantizationFormat +from llama_models.llama3.api.datatypes import SamplingParams from llama_models.sku_list import LlamaDownloadInfo +from pydantic import BaseModel, ConfigDict, Field + class PromptGuardModel(BaseModel): """Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed.""" diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index f18d262c0..54d78ad93 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -3,21 +3,28 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - import argparse - -from llama_stack.cli.subcommand import Subcommand -from llama_stack.distribution.datatypes import * # noqa: F403 import os import shutil from functools import lru_cache from pathlib import Path +from typing import List, Optional import pkg_resources +from llama_stack.cli.subcommand import Subcommand + +from llama_stack.distribution.datatypes import ( + BuildConfig, + DistributionSpec, + Provider, + StackRunConfig, +) + from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.utils.dynamic import instantiate_class_type +from llama_stack.providers.datatypes import Api TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates" diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py index bdda0349f..f376301f9 100644 --- a/llama_stack/distribution/build.py +++ b/llama_stack/distribution/build.py @@ -6,21 +6,22 @@ import logging from enum import Enum -from typing import List + +from pathlib import Path +from typing import Dict, List import pkg_resources from pydantic import BaseModel from termcolor import cprint -from llama_stack.distribution.utils.exec import run_with_pty - -from llama_stack.distribution.datatypes import * # noqa: F403 -from pathlib import Path +from llama_stack.distribution.datatypes import BuildConfig, Provider from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR +from llama_stack.distribution.utils.exec import run_with_pty +from llama_stack.providers.datatypes import Api log = logging.getLogger(__name__) diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index a4d0f970b..71c2676de 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -6,10 +6,14 @@ import logging import textwrap -from typing import Any - -from llama_stack.distribution.datatypes import * # noqa: F403 +from typing import Any, Dict +from llama_stack.distribution.datatypes import ( + DistributionSpec, + LLAMA_STACK_RUN_CONFIG_VERSION, + Provider, + StackRunConfig, +) from llama_stack.distribution.distribution import ( builtin_automatically_routed_apis, get_provider_registry, @@ -17,10 +21,7 @@ from llama_stack.distribution.distribution import ( from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.prompt_for_config import prompt_for_config - -from llama_stack.apis.models import * # noqa: F403 -from llama_stack.apis.shields import * # noqa: F403 -from llama_stack.apis.memory_banks import * # noqa: F403 +from llama_stack.providers.datatypes import Api, ProviderSpec logger = logging.getLogger(__name__) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index f2dea6012..dec62bfae 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -4,24 +4,24 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Dict, List, Optional, Union +from typing import Annotated, Any, Dict, List, Optional, Union from pydantic import BaseModel, Field from llama_stack.apis.datasetio import DatasetIO -from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.datasets import Dataset, DatasetInput from llama_stack.apis.eval import Eval -from llama_stack.apis.eval_tasks import EvalTaskInput +from llama_stack.apis.eval_tasks import EvalTask, EvalTaskInput from llama_stack.apis.inference import Inference from llama_stack.apis.memory import Memory -from llama_stack.apis.memory_banks import * # noqa: F403 -from llama_stack.apis.models import * # noqa: F403 +from llama_stack.apis.memory_banks import MemoryBank, MemoryBankInput +from llama_stack.apis.models import Model, ModelInput from llama_stack.apis.safety import Safety from llama_stack.apis.scoring import Scoring -from llama_stack.apis.scoring_functions import * # noqa: F403 -from llama_stack.apis.shields import * # noqa: F403 +from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput +from llama_stack.apis.shields import Shield, ShieldInput from llama_stack.apis.tools import Tool, ToolGroup, ToolRuntime -from llama_stack.providers.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.utils.kvstore.config import KVStoreConfig LLAMA_STACK_BUILD_CONFIG_VERSION = "2" diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py index f5716ef5e..dbb16d8ce 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/distribution/inspect.py @@ -5,12 +5,12 @@ # the root directory of this source tree. from typing import Dict, List -from llama_stack.apis.inspect import * # noqa: F403 + from pydantic import BaseModel +from llama_stack.apis.inspect import HealthInfo, Inspect, ProviderInfo, RouteInfo +from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.server.endpoints import get_all_api_endpoints -from llama_stack.providers.datatypes import * # noqa: F403 -from llama_stack.distribution.datatypes import * # noqa: F403 class DistributionInspectConfig(BaseModel): diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 439971315..0a6eed345 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -6,14 +6,10 @@ import importlib import inspect -from typing import Any, Dict, List, Set - - -from llama_stack.providers.datatypes import * # noqa: F403 -from llama_stack.distribution.datatypes import * # noqa: F403 - import logging +from typing import Any, Dict, List, Set + from llama_stack.apis.agents import Agents from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets @@ -32,10 +28,32 @@ from llama_stack.apis.shields import Shields from llama_stack.apis.telemetry import Telemetry from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.distribution.client import get_client_impl + +from llama_stack.distribution.datatypes import ( + AutoRoutedProviderSpec, + Provider, + RoutingTableProviderSpec, + StackRunConfig, +) from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type +from llama_stack.providers.datatypes import ( + Api, + DatasetsProtocolPrivate, + EvalTasksProtocolPrivate, + InlineProviderSpec, + MemoryBanksProtocolPrivate, + ModelsProtocolPrivate, + ProviderSpec, + RemoteProviderConfig, + RemoteProviderSpec, + ScoringFunctionsProtocolPrivate, + ShieldsProtocolPrivate, + ToolsProtocolPrivate, +) + log = logging.getLogger(__name__) diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 693f1fbe2..f19a2bffc 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -4,10 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any +from typing import Any, Dict + +from llama_stack.distribution.datatypes import RoutedProtocol -from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.store import DistributionRegistry +from llama_stack.providers.datatypes import Api, RoutingTable from .routing_tables import ( DatasetsRoutingTable, diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index a25a848db..84ef467eb 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -6,16 +6,40 @@ from typing import Any, AsyncGenerator, Dict, List, Optional -from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.apis.datasetio.datasetio import DatasetIO -from llama_stack.apis.eval import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult +from llama_stack.apis.eval import ( + AppEvalTaskConfig, + Eval, + EvalTaskConfig, + EvaluateResponse, + Job, + JobStatus, +) +from llama_stack.apis.inference import ( + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse from llama_stack.apis.memory_banks.memory_banks import BankParams -from llama_stack.apis.safety import * # noqa: F403 -from llama_stack.apis.scoring import * # noqa: F403 -from llama_stack.apis.tools import * # noqa: F403 -from llama_stack.distribution.datatypes import RoutingTable +from llama_stack.apis.models import ModelType +from llama_stack.apis.safety import RunShieldResponse, Safety +from llama_stack.apis.scoring import ( + ScoreBatchResponse, + ScoreResponse, + Scoring, + ScoringFnParams, +) +from llama_stack.apis.shields import Shield +from llama_stack.apis.tools import Tool, ToolGroupDef, ToolRuntime +from llama_stack.providers.datatypes import RoutingTable class MemoryRouter(Memory): @@ -330,7 +354,6 @@ class EvalRouter(Eval): task_config=task_config, ) - @webmethod(route="/eval/evaluate_rows", method="POST") async def evaluate_rows( self, task_id: str, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 3fb086b72..ab1becfdd 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -6,19 +6,42 @@ from typing import Any, Dict, List, Optional -from llama_models.llama3.api.datatypes import * # noqa: F403 from pydantic import parse_obj_as from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ParamType -from llama_stack.apis.datasets import * # noqa: F403 -from llama_stack.apis.eval_tasks import * # noqa: F403 -from llama_stack.apis.memory_banks import * # noqa: F403 -from llama_stack.apis.models import * # noqa: F403 -from llama_stack.apis.shields import * # noqa: F403 -from llama_stack.apis.tools import * # noqa: F403 -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.apis.datasets import Dataset, Datasets +from llama_stack.apis.eval_tasks import EvalTask, EvalTasks +from llama_stack.apis.memory_banks import ( + BankParams, + MemoryBank, + MemoryBanks, + MemoryBankType, +) +from llama_stack.apis.models import Model, Models, ModelType +from llama_stack.apis.resource import ResourceType +from llama_stack.apis.scoring_functions import ( + ScoringFn, + ScoringFnParams, + ScoringFunctions, +) +from llama_stack.apis.shields import Shield, Shields +from llama_stack.apis.tools import ( + MCPToolGroupDef, + Tool, + ToolGroup, + ToolGroupDef, + ToolGroups, + UserDefinedToolGroupDef, +) +from llama_stack.distribution.datatypes import ( + RoutableObject, + RoutableObjectWithProvider, + RoutedProtocol, +) + from llama_stack.distribution.store import DistributionRegistry +from llama_stack.providers.datatypes import Api, RoutingTable def get_impl_api(p: Any) -> Api: diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 8f24f3eaf..daaf8475b 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -28,14 +28,9 @@ from pydantic import BaseModel, ValidationError from termcolor import cprint from typing_extensions import Annotated -from llama_stack.distribution.distribution import builtin_automatically_routed_apis +from llama_stack.distribution.datatypes import StackRunConfig -from llama_stack.providers.utils.telemetry.tracing import ( - end_trace, - setup_logger, - start_trace, -) -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.stack import ( @@ -43,11 +38,19 @@ from llama_stack.distribution.stack import ( replace_env_vars, validate_env_pair, ) + +from llama_stack.providers.datatypes import Api from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig from llama_stack.providers.inline.telemetry.meta_reference.telemetry import ( TelemetryAdapter, ) +from llama_stack.providers.utils.telemetry.tracing import ( + end_trace, + setup_logger, + start_trace, +) + from .endpoints import get_all_api_endpoints diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index f5180b0db..965df5f03 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -8,32 +8,31 @@ import logging import os import re from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, Optional import pkg_resources import yaml from termcolor import colored -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.agents import * # noqa: F403 -from llama_stack.apis.datasets import * # noqa: F403 -from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.apis.scoring import * # noqa: F403 -from llama_stack.apis.scoring_functions import * # noqa: F403 -from llama_stack.apis.eval import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.batch_inference import * # noqa: F403 -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.telemetry import * # noqa: F403 -from llama_stack.apis.post_training import * # noqa: F403 -from llama_stack.apis.synthetic_data_generation import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 -from llama_stack.apis.models import * # noqa: F403 -from llama_stack.apis.memory_banks import * # noqa: F403 -from llama_stack.apis.shields import * # noqa: F403 -from llama_stack.apis.inspect import * # noqa: F403 -from llama_stack.apis.eval_tasks import * # noqa: F403 +from llama_stack.apis.agents import Agents +from llama_stack.apis.batch_inference import BatchInference +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import Datasets +from llama_stack.apis.eval import Eval +from llama_stack.apis.eval_tasks import EvalTasks +from llama_stack.apis.inference import Inference +from llama_stack.apis.inspect import Inspect +from llama_stack.apis.memory import Memory +from llama_stack.apis.memory_banks import MemoryBanks +from llama_stack.apis.models import Models +from llama_stack.apis.post_training import PostTraining +from llama_stack.apis.safety import Safety +from llama_stack.apis.scoring import Scoring +from llama_stack.apis.scoring_functions import ScoringFunctions +from llama_stack.apis.shields import Shields +from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration +from llama_stack.apis.telemetry import Telemetry from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.distribution import get_provider_registry diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index f98c14443..686054dd2 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -13,11 +13,8 @@ import pydantic from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR -from llama_stack.providers.utils.kvstore import ( - KVStore, - kvstore_impl, - SqliteKVStoreConfig, -) +from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig class DistributionRegistry(Protocol): diff --git a/llama_stack/distribution/store/tests/test_registry.py b/llama_stack/distribution/store/tests/test_registry.py index 7e389cccd..54bc04f9c 100644 --- a/llama_stack/distribution/store/tests/test_registry.py +++ b/llama_stack/distribution/store/tests/test_registry.py @@ -8,11 +8,14 @@ import os import pytest import pytest_asyncio -from llama_stack.distribution.store import * # noqa F403 from llama_stack.apis.inference import Model from llama_stack.apis.memory_banks import VectorMemoryBank + +from llama_stack.distribution.store.registry import ( + CachedDiskDistributionRegistry, + DiskDistributionRegistry, +) from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig -from llama_stack.distribution.datatypes import * # noqa F403 @pytest.fixture diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index d7930550d..f225f5393 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -13,19 +13,64 @@ import secrets import string import uuid from datetime import datetime -from typing import AsyncGenerator, List, Tuple +from typing import AsyncGenerator, Dict, List, Optional, Tuple from urllib.parse import urlparse import httpx +from llama_models.llama3.api.datatypes import BuiltinTool -from llama_stack.apis.agents import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.memory_banks import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 +from llama_stack.apis.agents import ( + AgentConfig, + AgentTool, + AgentTurnCreateRequest, + AgentTurnResponseEvent, + AgentTurnResponseEventType, + AgentTurnResponseStepCompletePayload, + AgentTurnResponseStepProgressPayload, + AgentTurnResponseStepStartPayload, + AgentTurnResponseStreamChunk, + AgentTurnResponseTurnCompletePayload, + AgentTurnResponseTurnStartPayload, + Attachment, + CodeInterpreterToolDefinition, + FunctionCallToolDefinition, + InferenceStep, + MemoryRetrievalStep, + MemoryToolDefinition, + PhotogenToolDefinition, + SearchToolDefinition, + ShieldCallStep, + StepType, + ToolExecutionStep, + Turn, + WolframAlphaToolDefinition, +) -from llama_stack.apis.common.content_types import InterleavedContent, TextContentItem +from llama_stack.apis.common.content_types import ( + InterleavedContent, + TextContentItem, + URL, +) +from llama_stack.apis.inference import ( + ChatCompletionResponseEventType, + CompletionMessage, + Inference, + Message, + SamplingParams, + StopReason, + SystemMessage, + ToolCallDelta, + ToolCallParseStatus, + ToolChoice, + ToolDefinition, + ToolResponse, + ToolResponseMessage, + UserMessage, +) +from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse +from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams +from llama_stack.apis.safety import Safety from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index dec5ec960..93bfab5f4 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -9,15 +9,26 @@ import logging import shutil import tempfile import uuid -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional, Union from termcolor import colored -from llama_stack.apis.inference import Inference +from llama_stack.apis.agents import ( + AgentConfig, + AgentCreateResponse, + Agents, + AgentSessionCreateResponse, + AgentStepResponse, + AgentTurnCreateRequest, + Attachment, + Session, + Turn, +) + +from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage from llama_stack.apis.memory import Memory from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.safety import Safety -from llama_stack.apis.agents import * # noqa: F403 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 1c99e3d75..a4b1af616 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -10,9 +10,11 @@ import uuid from datetime import datetime from typing import List, Optional -from llama_stack.apis.agents import * # noqa: F403 + from pydantic import BaseModel +from llama_stack.apis.agents import Turn + from llama_stack.providers.utils.kvstore import KVStore log = logging.getLogger(__name__) diff --git a/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py b/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py index 7b5c8b4b0..74eb91c53 100644 --- a/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py +++ b/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py @@ -7,8 +7,6 @@ from typing import List from jinja2 import Template -from llama_models.llama3.api import * # noqa: F403 - from llama_stack.apis.agents import ( DefaultMemoryQueryGeneratorConfig, @@ -16,7 +14,7 @@ from llama_stack.apis.agents import ( MemoryQueryGenerator, MemoryQueryGeneratorConfig, ) -from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.inference import Message, UserMessage from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) diff --git a/llama_stack/providers/inline/agents/meta_reference/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py index 8fca4d310..90d193f90 100644 --- a/llama_stack/providers/inline/agents/meta_reference/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/safety.py @@ -9,7 +9,9 @@ import logging from typing import List -from llama_stack.apis.safety import * # noqa: F403 +from llama_stack.apis.inference import Message + +from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel log = logging.getLogger(__name__) diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py index 6edef0672..035054320 100644 --- a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py +++ b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py @@ -8,10 +8,26 @@ from typing import AsyncIterator, List, Optional, Union import pytest -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 -from llama_stack.apis.agents import * # noqa: F403 +from llama_stack.apis.agents import ( + AgentConfig, + AgentTurnCreateRequest, + AgentTurnResponseTurnCompletePayload, +) + +from llama_stack.apis.inference import ( + ChatCompletionResponse, + ChatCompletionResponseEvent, + ChatCompletionResponseStreamChunk, + CompletionMessage, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + UserMessage, +) +from llama_stack.apis.memory import MemoryBank +from llama_stack.apis.safety import RunShieldResponse from ..agents import ( AGENT_INSTANCES_BY_ID, diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/safety.py b/llama_stack/providers/inline/agents/meta_reference/tools/safety.py index 1ffc99edd..a34649756 100644 --- a/llama_stack/providers/inline/agents/meta_reference/tools/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/tools/safety.py @@ -7,7 +7,7 @@ from typing import List from llama_stack.apis.inference import Message -from llama_stack.apis.safety import * # noqa: F403 +from llama_stack.apis.safety import Safety from ..safety import ShieldRunnerMixin from .builtin import BaseTool diff --git a/llama_stack/providers/inline/datasetio/localfs/config.py b/llama_stack/providers/inline/datasetio/localfs/config.py index 58d563c99..1b89df63b 100644 --- a/llama_stack/providers/inline/datasetio/localfs/config.py +++ b/llama_stack/providers/inline/datasetio/localfs/config.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.apis.datasetio import * # noqa: F401, F403 +from pydantic import BaseModel class LocalFSDatasetIOConfig(BaseModel): ... diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index 736e5d8b9..442053fb3 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -3,18 +3,19 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Optional - -import pandas -from llama_models.llama3.api.datatypes import * # noqa: F403 - -from llama_stack.apis.datasetio import * # noqa: F403 import base64 import os from abc import ABC, abstractmethod from dataclasses import dataclass +from typing import Any, Dict, List, Optional from urllib.parse import urlparse +import pandas + +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult +from llama_stack.apis.datasets import Dataset + from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index e1c2cc804..00630132e 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -5,13 +5,15 @@ # the root directory of this source tree. from enum import Enum from typing import Any, Dict, List, Optional -from llama_models.llama3.api.datatypes import * # noqa: F403 + from tqdm import tqdm -from .....apis.common.job_types import Job -from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus -from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.agents import Agents +from llama_stack.apis.common.type_system import ( + ChatCompletionInputType, + CompletionInputType, + StringType, +) from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.eval_tasks import EvalTask @@ -20,6 +22,9 @@ from llama_stack.apis.scoring import Scoring from llama_stack.providers.datatypes import EvalTasksProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl +from .....apis.common.job_types import Job +from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus + from .config import MetaReferenceEvalConfig EVAL_TASKS_PREFIX = "eval_tasks:" diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py index 33af33fcd..2c46ef596 100644 --- a/llama_stack/providers/inline/inference/meta_reference/config.py +++ b/llama_stack/providers/inline/inference/meta_reference/config.py @@ -6,11 +6,10 @@ from typing import Any, Dict, Optional -from llama_models.datatypes import * # noqa: F403 - -from llama_stack.apis.inference import * # noqa: F401, F403 from pydantic import BaseModel, field_validator +from llama_stack.apis.inference import QuantizationConfig + from llama_stack.providers.utils.inference import supported_inference_models diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py index c89183cb7..1807e4ad5 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generation.py +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py @@ -32,11 +32,16 @@ from llama_models.llama3.reference_impl.multimodal.model import ( CrossAttentionTransformer, ) from llama_models.sku_list import resolve_model -from pydantic import BaseModel - -from llama_stack.apis.inference import * # noqa: F403 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData +from pydantic import BaseModel + +from llama_stack.apis.inference import ( + Fp8QuantizationConfig, + Int4QuantizationConfig, + ResponseFormat, + ResponseFormatType, +) from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.providers.utils.inference.prompt_adapter import ( @@ -44,12 +49,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( CompletionRequestWithRawContent, ) -from .config import ( - Fp8QuantizationConfig, - Int4QuantizationConfig, - MetaReferenceInferenceConfig, - MetaReferenceQuantizedInferenceConfig, -) +from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig log = logging.getLogger(__name__) diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index c5925774b..73f7adecd 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -7,10 +7,10 @@ import logging import os import uuid -from typing import AsyncGenerator, Optional +from typing import AsyncGenerator, List, Optional from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import * # noqa: F403 + from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import resolve_model @@ -18,9 +18,26 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.sampling_params import SamplingParams as VLLMSamplingParams -from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseStreamChunk, + CompletionResponse, + CompletionResponseStreamChunk, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.apis.models import Model -from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate +from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.openai_compat import ( OpenAICompatCompletionChoice, OpenAICompatCompletionResponse, diff --git a/llama_stack/providers/inline/memory/faiss/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py index a46b151d9..af398801a 100644 --- a/llama_stack/providers/inline/memory/faiss/faiss.py +++ b/llama_stack/providers/inline/memory/faiss/faiss.py @@ -16,11 +16,14 @@ import faiss import numpy as np from numpy.typing import NDArray -from llama_models.llama3.api.datatypes import * # noqa: F403 - -from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.inference import InterleavedContent -from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank +from llama_stack.apis.memory import ( + Chunk, + Memory, + MemoryBankDocument, + QueryDocumentsResponse, +) +from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.memory.vector_store import ( diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py index 462cbc21e..f2a2edae5 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py @@ -14,11 +14,10 @@ from enum import Enum from typing import Any, Callable, Dict, List import torch -from llama_stack.apis.datasets import Datasets -from llama_stack.apis.common.type_system import * # noqa from llama_models.datatypes import Model from llama_models.sku_list import resolve_model -from llama_stack.apis.common.type_system import ParamType +from llama_stack.apis.common.type_system import ParamType, StringType +from llama_stack.apis.datasets import Datasets from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b from torchtune.models.llama3._tokenizer import Llama3Tokenizer diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py index 9b1269f16..90fbf7026 100644 --- a/llama_stack/providers/inline/post_training/torchtune/post_training.py +++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py @@ -3,11 +3,26 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from datetime import datetime +from typing import Any, Dict, List, Optional + +from llama_models.schema_utils import webmethod + from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import Datasets +from llama_stack.apis.post_training import ( + AlgorithmConfig, + DPOAlignmentConfig, + JobStatus, + LoraFinetuningConfig, + PostTrainingJob, + PostTrainingJobArtifactsResponse, + PostTrainingJobStatusResponse, + TrainingConfig, +) from llama_stack.providers.inline.post_training.torchtune.config import ( TorchtunePostTrainingConfig, ) -from llama_stack.apis.post_training import * # noqa from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import ( LoraFinetuningSingleDevice, ) diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 71b8bf759..517be6d89 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -14,27 +14,33 @@ from typing import Any, Dict, List, Optional, Tuple import torch from llama_models.sku_list import resolve_model +from llama_stack.apis.common.training_types import PostTrainingMetric from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import Datasets +from llama_stack.apis.post_training import ( + AlgorithmConfig, + Checkpoint, + LoraFinetuningConfig, + OptimizerConfig, + TrainingConfig, +) from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR -from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import ( - TorchtuneCheckpointer, -) -from torch import nn -from torchtune import utils as torchtune_utils -from torchtune.training.metric_logging import DiskLogger -from tqdm import tqdm -from llama_stack.apis.post_training import * # noqa + from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.providers.inline.post_training.torchtune.common import utils +from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import ( + TorchtuneCheckpointer, +) from llama_stack.providers.inline.post_training.torchtune.config import ( TorchtunePostTrainingConfig, ) from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset +from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import modules, training +from torchtune import modules, training, utils as torchtune_utils from torchtune.data import AlpacaToMessages, padded_collate_sft from torchtune.modules.loss import CEWithChunkedOutputLoss @@ -47,6 +53,8 @@ from torchtune.modules.peft import ( validate_missing_and_unexpected_for_lora, ) from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup +from torchtune.training.metric_logging import DiskLogger +from tqdm import tqdm log = logging.getLogger(__name__) diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py index 46b5e57da..87d68f74c 100644 --- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py +++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py @@ -7,8 +7,14 @@ import logging from typing import Any, Dict, List -from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.inference import Message +from llama_stack.apis.safety import ( + RunShieldResponse, + Safety, + SafetyViolation, + ViolationLevel, +) +from llama_stack.apis.shields import Shield from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index bbdd5c3df..00213ac83 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -9,10 +9,24 @@ import re from string import Template from typing import Any, Dict, List, Optional -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 +from llama_models.datatypes import CoreModelId +from llama_models.llama3.api.datatypes import Role + from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem +from llama_stack.apis.inference import ( + ChatCompletionResponseEventType, + Inference, + Message, + UserMessage, +) +from llama_stack.apis.safety import ( + RunShieldResponse, + Safety, + SafetyViolation, + ViolationLevel, +) + +from llama_stack.apis.shields import Shield from llama_stack.distribution.datatypes import Api from llama_stack.providers.datatypes import ShieldsProtocolPrivate diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index 4cb34127f..3f30645bd 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -11,11 +11,16 @@ import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer -from llama_stack.distribution.utils.model_utils import model_local_dir -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 -from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.inference import Message +from llama_stack.apis.safety import ( + RunShieldResponse, + Safety, + SafetyViolation, + ViolationLevel, +) +from llama_stack.apis.shields import Shield +from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py index 0c0503ff5..f8b30cbcf 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring.py +++ b/llama_stack/providers/inline/scoring/basic/scoring.py @@ -3,14 +3,17 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List +from typing import Any, Dict, List, Optional -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.scoring import * # noqa: F403 -from llama_stack.apis.scoring_functions import * # noqa: F403 -from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import Datasets +from llama_stack.apis.scoring import ( + ScoreBatchResponse, + ScoreResponse, + Scoring, + ScoringResult, +) +from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from .config import BasicScoringConfig diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py index ae9555403..0c6102645 100644 --- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py +++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py @@ -3,20 +3,23 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List - -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.scoring import * # noqa: F403 -from llama_stack.apis.scoring_functions import * # noqa: F403 -from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.apis.datasets import * # noqa: F403 - import os +from typing import Any, Dict, List, Optional from autoevals.llm import Factuality from autoevals.ragas import AnswerCorrectness +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import Datasets +from llama_stack.apis.scoring import ( + ScoreBatchResponse, + ScoreResponse, + Scoring, + ScoringResult, + ScoringResultRow, +) +from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFn + from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate diff --git a/llama_stack/providers/inline/scoring/braintrust/config.py b/llama_stack/providers/inline/scoring/braintrust/config.py index e12249432..d4e0d9bcd 100644 --- a/llama_stack/providers/inline/scoring/braintrust/config.py +++ b/llama_stack/providers/inline/scoring/braintrust/config.py @@ -3,7 +3,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.apis.scoring import * # noqa: F401, F403 +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field class BraintrustScoringConfig(BaseModel): diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index d7229f508..81dd9910d 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -17,6 +17,22 @@ from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.semconv.resource import ResourceAttributes +from llama_stack.apis.telemetry import ( + Event, + MetricEvent, + QueryCondition, + SpanEndPayload, + SpanStartPayload, + SpanStatus, + SpanWithStatus, + StructuredLogEvent, + Telemetry, + Trace, + UnstructuredLogEvent, +) + +from llama_stack.distribution.datatypes import Api + from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import ( ConsoleSpanProcessor, ) @@ -27,10 +43,6 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore -from llama_stack.apis.telemetry import * # noqa: F403 - -from llama_stack.distribution.datatypes import Api - from .config import TelemetryConfig, TelemetrySink _GLOBAL_STORAGE = { diff --git a/llama_stack/providers/inline/telemetry/sample/sample.py b/llama_stack/providers/inline/telemetry/sample/sample.py index eaa6d834a..f07a185ef 100644 --- a/llama_stack/providers/inline/telemetry/sample/sample.py +++ b/llama_stack/providers/inline/telemetry/sample/sample.py @@ -4,12 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from llama_stack.apis.telemetry import Telemetry from .config import SampleConfig -from llama_stack.apis.telemetry import * # noqa: F403 - - class SampleTelemetryImpl(Telemetry): def __init__(self, config: SampleConfig): self.config = config diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py index 8b6c9027c..6595b1955 100644 --- a/llama_stack/providers/registry/agents.py +++ b/llama_stack/providers/registry/agents.py @@ -6,7 +6,13 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + InlineProviderSpec, + ProviderSpec, + remote_provider_spec, +) from llama_stack.providers.utils.kvstore import kvstore_dependencies diff --git a/llama_stack/providers/registry/datasetio.py b/llama_stack/providers/registry/datasetio.py index 403c41111..f83dcbc60 100644 --- a/llama_stack/providers/registry/datasetio.py +++ b/llama_stack/providers/registry/datasetio.py @@ -6,7 +6,13 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + InlineProviderSpec, + ProviderSpec, + remote_provider_spec, +) def available_providers() -> List[ProviderSpec]: diff --git a/llama_stack/providers/registry/eval.py b/llama_stack/providers/registry/eval.py index 718c7eae5..6901c3741 100644 --- a/llama_stack/providers/registry/eval.py +++ b/llama_stack/providers/registry/eval.py @@ -6,7 +6,7 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec def available_providers() -> List[ProviderSpec]: diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 0ff557b9f..397e8b7ee 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -6,8 +6,13 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 - +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + InlineProviderSpec, + ProviderSpec, + remote_provider_spec, +) META_REFERENCE_DEPS = [ "accelerate", diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py index c18bd3873..6867a9186 100644 --- a/llama_stack/providers/registry/memory.py +++ b/llama_stack/providers/registry/memory.py @@ -6,8 +6,13 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 - +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + InlineProviderSpec, + ProviderSpec, + remote_provider_spec, +) EMBEDDING_DEPS = [ "blobfile", diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py index af8b660fa..3c5d06c05 100644 --- a/llama_stack/providers/registry/post_training.py +++ b/llama_stack/providers/registry/post_training.py @@ -6,7 +6,7 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec def available_providers() -> List[ProviderSpec]: diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 99b0d2bd8..b9f7b6d78 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -6,7 +6,7 @@ from typing import List -from llama_stack.distribution.datatypes import ( +from llama_stack.providers.datatypes import ( AdapterSpec, Api, InlineProviderSpec, diff --git a/llama_stack/providers/registry/scoring.py b/llama_stack/providers/registry/scoring.py index f31ff44d7..ca09be984 100644 --- a/llama_stack/providers/registry/scoring.py +++ b/llama_stack/providers/registry/scoring.py @@ -6,7 +6,7 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec def available_providers() -> List[ProviderSpec]: diff --git a/llama_stack/providers/registry/telemetry.py b/llama_stack/providers/registry/telemetry.py index d367bf894..ba7e2f806 100644 --- a/llama_stack/providers/registry/telemetry.py +++ b/llama_stack/providers/registry/telemetry.py @@ -6,7 +6,13 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + InlineProviderSpec, + ProviderSpec, + remote_provider_spec, +) def available_providers() -> List[ProviderSpec]: diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index f3e6aead8..042aef9d9 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -6,7 +6,7 @@ from typing import List -from llama_stack.distribution.datatypes import ( +from llama_stack.providers.datatypes import ( AdapterSpec, Api, InlineProviderSpec, diff --git a/llama_stack/providers/remote/agents/sample/sample.py b/llama_stack/providers/remote/agents/sample/sample.py index e9a3a6ee5..f8b312f1e 100644 --- a/llama_stack/providers/remote/agents/sample/sample.py +++ b/llama_stack/providers/remote/agents/sample/sample.py @@ -4,12 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from llama_stack.apis.agents import Agents from .config import SampleConfig -from llama_stack.apis.agents import * # noqa: F403 - - class SampleAgentsImpl(Agents): def __init__(self, config: SampleConfig): self.config = config diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index 2fde7c3d0..47a63677e 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -5,11 +5,11 @@ # the root directory of this source tree. from typing import Any, Dict, List, Optional -from llama_stack.apis.datasetio import * # noqa: F403 - - import datasets as hf_datasets +from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult +from llama_stack.apis.datasets import Dataset + from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url from llama_stack.providers.utils.kvstore import kvstore_impl diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index ddf59fda8..d340bbbea 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -4,8 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import * # noqa: F403 import json +from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from botocore.client import BaseClient from llama_models.datatypes import CoreModelId @@ -13,6 +13,24 @@ from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseStreamChunk, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig +from llama_stack.providers.utils.bedrock.client import create_bedrock_client + from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, ModelRegistryHelper, @@ -29,11 +47,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) -from llama_stack.apis.inference import * # noqa: F403 - -from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig -from llama_stack.providers.utils.bedrock.client import create_bedrock_client - MODEL_ALIASES = [ build_model_alias( diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 2ff213c2e..40457e1ae 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -4,17 +4,31 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional, Union from cerebras.cloud.sdk import AsyncCerebras +from llama_models.datatypes import CoreModelId + from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer -from llama_stack.apis.inference import * # noqa: F403 - -from llama_models.datatypes import CoreModelId +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + CompletionRequest, + CompletionResponse, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 155b230bb..3d88423c5 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional from llama_models.datatypes import CoreModelId @@ -14,7 +14,20 @@ from llama_models.llama3.api.tokenizer import Tokenizer from openai import OpenAI -from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 975ec4893..7a00194ac 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -11,7 +11,24 @@ from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer -from llama_stack.apis.inference import * # noqa: F403 + +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + CompletionResponse, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + ResponseFormatType, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 920f3dd7e..88f985f3a 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import logging -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional, Union import httpx from llama_models.datatypes import CoreModelId @@ -14,15 +14,33 @@ from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer from ollama import AsyncClient +from llama_stack.apis.common.content_types import ( + ImageContentItem, + InterleavedContent, + TextContentItem, +) +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.apis.models import Model, ModelType +from llama_stack.providers.datatypes import ModelsProtocolPrivate + from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, build_model_alias_with_just_provider_model_id, ModelRegistryHelper, ) - -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem -from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, OpenAICompatCompletionChoice, diff --git a/llama_stack/providers/remote/inference/sample/sample.py b/llama_stack/providers/remote/inference/sample/sample.py index 79ce1ffe4..51ce879eb 100644 --- a/llama_stack/providers/remote/inference/sample/sample.py +++ b/llama_stack/providers/remote/inference/sample/sample.py @@ -4,12 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from llama_stack.apis.inference import Inference +from llama_stack.apis.models import Model from .config import SampleConfig -from llama_stack.apis.inference import * # noqa: F403 - - class SampleInferenceImpl(Inference): def __init__(self, config: SampleConfig): self.config = config diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 5cc476fd7..dd02c055a 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -13,10 +13,25 @@ from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import all_registered_models -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.models import * # noqa: F403 +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + ResponseFormatType, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.apis.models import Model -from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate +from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, ModelRegistryHelper, diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index e12a2cc0a..6b5a6a3b0 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional, Union from llama_models.datatypes import CoreModelId @@ -14,7 +14,22 @@ from llama_models.llama3.api.tokenizer import Tokenizer from together import Together -from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + ResponseFormatType, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 7250d901f..f62ccaa58 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import logging -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional, Union from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer @@ -13,7 +13,25 @@ from llama_models.sku_list import all_registered_models from openai import OpenAI -from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + CompletionResponse, + CompletionResponseStreamChunk, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + ResponseFormatType, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.apis.models import Model, ModelType from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py index aa8b481a3..c04d775ca 100644 --- a/llama_stack/providers/remote/memory/chroma/chroma.py +++ b/llama_stack/providers/remote/memory/chroma/chroma.py @@ -12,8 +12,14 @@ from urllib.parse import urlparse import chromadb from numpy.typing import NDArray -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.memory_banks import MemoryBankType +from llama_stack.apis.inference import InterleavedContent +from llama_stack.apis.memory import ( + Chunk, + Memory, + MemoryBankDocument, + QueryDocumentsResponse, +) +from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig from llama_stack.providers.utils.memory.vector_store import ( diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py index ffe164ecb..b2c720b2c 100644 --- a/llama_stack/providers/remote/memory/pgvector/pgvector.py +++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import logging -from typing import List, Tuple +from typing import Any, Dict, List, Optional, Tuple import psycopg2 from numpy.typing import NDArray @@ -14,8 +14,14 @@ from psycopg2.extras import execute_values, Json from pydantic import BaseModel, parse_obj_as -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank +from llama_stack.apis.inference import InterleavedContent +from llama_stack.apis.memory import ( + Chunk, + Memory, + MemoryBankDocument, + QueryDocumentsResponse, +) +from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.utils.memory.vector_store import ( diff --git a/llama_stack/providers/remote/memory/qdrant/qdrant.py b/llama_stack/providers/remote/memory/qdrant/qdrant.py index bf9e943c4..b1d5bd7fa 100644 --- a/llama_stack/providers/remote/memory/qdrant/qdrant.py +++ b/llama_stack/providers/remote/memory/qdrant/qdrant.py @@ -6,16 +6,21 @@ import logging import uuid -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from numpy.typing import NDArray from qdrant_client import AsyncQdrantClient, models from qdrant_client.models import PointStruct -from llama_stack.apis.memory_banks import * # noqa: F403 +from llama_stack.apis.inference import InterleavedContent +from llama_stack.apis.memory import ( + Chunk, + Memory, + MemoryBankDocument, + QueryDocumentsResponse, +) +from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate -from llama_stack.apis.memory import * # noqa: F403 - from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig from llama_stack.providers.utils.memory.vector_store import ( BankWithIndex, diff --git a/llama_stack/providers/remote/memory/sample/sample.py b/llama_stack/providers/remote/memory/sample/sample.py index 09ea2f32c..b051eb544 100644 --- a/llama_stack/providers/remote/memory/sample/sample.py +++ b/llama_stack/providers/remote/memory/sample/sample.py @@ -4,12 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from llama_stack.apis.memory import Memory +from llama_stack.apis.memory_banks import MemoryBank from .config import SampleConfig -from llama_stack.apis.memory import * # noqa: F403 - - class SampleMemoryImpl(Memory): def __init__(self, config: SampleConfig): self.config = config diff --git a/llama_stack/providers/remote/memory/weaviate/weaviate.py b/llama_stack/providers/remote/memory/weaviate/weaviate.py index 8ee001cfa..f1433090d 100644 --- a/llama_stack/providers/remote/memory/weaviate/weaviate.py +++ b/llama_stack/providers/remote/memory/weaviate/weaviate.py @@ -14,8 +14,14 @@ from numpy.typing import NDArray from weaviate.classes.init import Auth from weaviate.classes.query import Filter -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.memory_banks import MemoryBankType +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.memory import ( + Chunk, + Memory, + MemoryBankDocument, + QueryDocumentsResponse, +) +from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.utils.memory.vector_store import ( diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index 78e8105e0..fba7bf342 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -9,8 +9,15 @@ import logging from typing import Any, Dict, List -from llama_stack.apis.safety import * # noqa -from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.inference import Message + +from llama_stack.apis.safety import ( + RunShieldResponse, + Safety, + SafetyViolation, + ViolationLevel, +) +from llama_stack.apis.shields import Shield from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.bedrock.client import create_bedrock_client diff --git a/llama_stack/providers/remote/safety/sample/sample.py b/llama_stack/providers/remote/safety/sample/sample.py index 4069b8789..180e6c3b5 100644 --- a/llama_stack/providers/remote/safety/sample/sample.py +++ b/llama_stack/providers/remote/safety/sample/sample.py @@ -4,12 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from llama_stack.apis.safety import Safety +from llama_stack.apis.shields import Shield from .config import SampleConfig -from llama_stack.apis.safety import * # noqa: F403 - - class SampleSafetyImpl(Safety): def __init__(self, config: SampleConfig): self.config = config diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index ee2f3d29f..dc95fa6a6 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -5,11 +5,31 @@ # the root directory of this source tree. import os +from typing import Dict, List import pytest +from llama_models.llama3.api.datatypes import BuiltinTool -from llama_stack.apis.agents import * # noqa: F403 -from llama_stack.providers.datatypes import * # noqa: F403 +from llama_stack.apis.agents import ( + AgentConfig, + AgentTool, + AgentTurnResponseEventType, + AgentTurnResponseStepCompletePayload, + AgentTurnResponseStreamChunk, + AgentTurnResponseTurnCompletePayload, + Attachment, + MemoryToolDefinition, + SearchEngineType, + SearchToolDefinition, + ShieldCallStep, + StepType, + ToolChoice, + ToolExecutionStep, + Turn, +) +from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage +from llama_stack.apis.safety import ViolationLevel +from llama_stack.providers.datatypes import Api # How to run this test: # diff --git a/llama_stack/providers/tests/agents/test_persistence.py b/llama_stack/providers/tests/agents/test_persistence.py index 97094cd7a..38eb7de55 100644 --- a/llama_stack/providers/tests/agents/test_persistence.py +++ b/llama_stack/providers/tests/agents/test_persistence.py @@ -6,9 +6,9 @@ import pytest -from llama_stack.apis.agents import * # noqa: F403 -from llama_stack.providers.datatypes import * # noqa: F403 - +from llama_stack.apis.agents import AgentConfig, Turn +from llama_stack.apis.inference import SamplingParams, UserMessage +from llama_stack.providers.datatypes import Api from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig from .fixtures import pick_inference_model diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index 7d88b6115..46c99f5b3 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -4,16 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import os - -import pytest -from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.distribution.datatypes import * # noqa: F403 import base64 import mimetypes +import os from pathlib import Path +import pytest + +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType +from llama_stack.apis.datasets import Datasets + # How to run this test: # # pytest llama_stack/providers/tests/datasetio/test_datasetio.py diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py index 38da74128..d6794d488 100644 --- a/llama_stack/providers/tests/eval/test_eval.py +++ b/llama_stack/providers/tests/eval/test_eval.py @@ -7,8 +7,7 @@ import pytest -from llama_models.llama3.api import SamplingParams, URL - +from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType from llama_stack.apis.eval.eval import ( @@ -16,6 +15,7 @@ from llama_stack.apis.eval.eval import ( BenchmarkEvalTaskConfig, ModelCandidate, ) +from llama_stack.apis.inference import SamplingParams from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams from llama_stack.distribution.datatypes import Api from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset diff --git a/llama_stack/providers/tests/inference/test_prompt_adapter.py b/llama_stack/providers/tests/inference/test_prompt_adapter.py index 2c222ffa1..4826e89d5 100644 --- a/llama_stack/providers/tests/inference/test_prompt_adapter.py +++ b/llama_stack/providers/tests/inference/test_prompt_adapter.py @@ -6,8 +6,14 @@ import unittest -from llama_models.llama3.api import * # noqa: F403 -from llama_stack.apis.inference.inference import * # noqa: F403 +from llama_models.llama3.api.datatypes import ( + BuiltinTool, + ToolDefinition, + ToolParamDefinition, + ToolPromptFormat, +) + +from llama_stack.apis.inference import ChatCompletionRequest, SystemMessage, UserMessage from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_messages, ) @@ -24,7 +30,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): UserMessage(content=content), ], ) - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, MODEL) self.assertEqual(len(messages), 2) self.assertEqual(messages[-1].content, content) self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) @@ -41,7 +47,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): ToolDefinition(tool_name=BuiltinTool.brave_search), ], ) - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, MODEL) self.assertEqual(len(messages), 2) self.assertEqual(messages[-1].content, content) self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) @@ -69,7 +75,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): ], tool_prompt_format=ToolPromptFormat.json, ) - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, MODEL) self.assertEqual(len(messages), 3) self.assertTrue("Environment: ipython" in messages[0].content) @@ -99,7 +105,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): ), ], ) - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, MODEL) self.assertEqual(len(messages), 3) self.assertTrue("Environment: ipython" in messages[0].content) @@ -121,7 +127,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): ToolDefinition(tool_name=BuiltinTool.code_interpreter), ], ) - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, MODEL) self.assertEqual(len(messages), 2, messages) self.assertTrue(messages[0].content.endswith(system_prompt)) diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 99a62ac08..2eeda0dbf 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -7,13 +7,32 @@ import pytest +from llama_models.llama3.api.datatypes import ( + SamplingParams, + StopReason, + ToolCall, + ToolDefinition, + ToolParamDefinition, + ToolPromptFormat, +) + from pydantic import BaseModel, ValidationError -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 - -from llama_stack.distribution.datatypes import * # noqa: F403 - +from llama_stack.apis.inference import ( + ChatCompletionResponse, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, + CompletionResponse, + CompletionResponseStreamChunk, + JsonSchemaResponseFormat, + LogProbConfig, + SystemMessage, + ToolCallDelta, + ToolCallParseStatus, + ToolChoice, + UserMessage, +) +from llama_stack.apis.models import Model from .utils import group_chunks diff --git a/llama_stack/providers/tests/inference/test_vision_inference.py b/llama_stack/providers/tests/inference/test_vision_inference.py index d58164676..1bdee051f 100644 --- a/llama_stack/providers/tests/inference/test_vision_inference.py +++ b/llama_stack/providers/tests/inference/test_vision_inference.py @@ -8,11 +8,16 @@ from pathlib import Path import pytest - -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem, URL +from llama_stack.apis.inference import ( + ChatCompletionResponse, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, + SamplingParams, + UserMessage, +) + from .utils import group_chunks THIS_DIR = Path(__file__).parent diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index b2a5a87c9..9a98526ab 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -10,8 +10,7 @@ import tempfile import pytest import pytest_asyncio -from llama_stack.apis.inference import ModelInput, ModelType - +from llama_stack.apis.models import ModelInput, ModelType from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig from llama_stack.providers.inline.memory.faiss import FaissImplConfig @@ -19,7 +18,7 @@ from llama_stack.providers.remote.memory.chroma import ChromaRemoteImplConfig from llama_stack.providers.remote.memory.pgvector import PGVectorConfig from llama_stack.providers.remote.memory.weaviate import WeaviateConfig from llama_stack.providers.tests.resolver import construct_stack_for_test -from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index 526aa646c..801b04dfc 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -8,14 +8,18 @@ import uuid import pytest -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.distribution.datatypes import * # noqa: F403 -from llama_stack.apis.memory_banks.memory_banks import VectorMemoryBankParams +from llama_stack.apis.memory import MemoryBankDocument, QueryDocumentsResponse + +from llama_stack.apis.memory_banks import ( + MemoryBank, + MemoryBanks, + VectorMemoryBankParams, +) # How to run this test: # # pytest llama_stack/providers/tests/memory/test_memory.py -# -m "meta_reference" +# -m "sentence_transformers" --env EMBEDDING_DIMENSION=384 # -v -s --tb=short --disable-warnings diff --git a/llama_stack/providers/tests/post_training/fixtures.py b/llama_stack/providers/tests/post_training/fixtures.py index 17d9668b2..fd8a9e4f6 100644 --- a/llama_stack/providers/tests/post_training/fixtures.py +++ b/llama_stack/providers/tests/post_training/fixtures.py @@ -7,8 +7,9 @@ import pytest import pytest_asyncio -from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.common.content_types import URL + +from llama_stack.apis.common.type_system import StringType from llama_stack.apis.datasets import DatasetInput from llama_stack.apis.models import ModelInput diff --git a/llama_stack/providers/tests/post_training/test_post_training.py b/llama_stack/providers/tests/post_training/test_post_training.py index 4ecc05187..0645cd555 100644 --- a/llama_stack/providers/tests/post_training/test_post_training.py +++ b/llama_stack/providers/tests/post_training/test_post_training.py @@ -4,9 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import pytest -from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.post_training import * # noqa: F403 -from llama_stack.distribution.datatypes import * # noqa: F403 + +from llama_stack.apis.common.type_system import JobStatus +from llama_stack.apis.post_training import ( + Checkpoint, + DataConfig, + LoraFinetuningConfig, + OptimizerConfig, + PostTrainingJob, + PostTrainingJobArtifactsResponse, + PostTrainingJobStatusResponse, + TrainingConfig, +) # How to run this test: # diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index 8bbb902cd..5a38aaecc 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -8,14 +8,24 @@ import json import tempfile from typing import Any, Dict, List, Optional -from llama_stack.distribution.datatypes import * # noqa: F403 +from pydantic import BaseModel + +from llama_stack.apis.datasets import DatasetInput +from llama_stack.apis.eval_tasks import EvalTaskInput +from llama_stack.apis.memory_banks import MemoryBankInput +from llama_stack.apis.models import ModelInput +from llama_stack.apis.scoring_functions import ScoringFnInput +from llama_stack.apis.shields import ShieldInput + from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config +from llama_stack.distribution.datatypes import Provider, StackRunConfig from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import resolve_remote_stack_impls from llama_stack.distribution.stack import construct_stack -from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig +from llama_stack.providers.datatypes import Api, RemoteProviderConfig +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig class TestStack(BaseModel): diff --git a/llama_stack/providers/tests/safety/test_safety.py b/llama_stack/providers/tests/safety/test_safety.py index b015e8b06..857fe57f9 100644 --- a/llama_stack/providers/tests/safety/test_safety.py +++ b/llama_stack/providers/tests/safety/test_safety.py @@ -6,11 +6,9 @@ import pytest -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 - -from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.apis.inference import UserMessage +from llama_stack.apis.safety import ViolationLevel +from llama_stack.apis.shields import Shield # How to run this test: # diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py index dce069df0..2643b8fd6 100644 --- a/llama_stack/providers/tests/scoring/test_scoring.py +++ b/llama_stack/providers/tests/scoring/test_scoring.py @@ -197,7 +197,7 @@ class TestScoring: judge_score_regexes=[r"Score: (\d+)"], aggregation_functions=aggr_fns, ) - elif x.provider_id == "basic": + elif x.provider_id == "basic" or x.provider_id == "braintrust": if "regex_parser" in x.identifier: scoring_functions[x.identifier] = RegexParserScoringFnParams( aggregation_functions=aggr_fns, diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 871e39aaa..ba63be2b6 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -4,17 +4,28 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import AsyncGenerator, Optional +from typing import AsyncGenerator, List, Optional from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import StopReason - -from llama_stack.apis.inference import * # noqa: F403 +from llama_models.llama3.api.datatypes import SamplingParams, StopReason from pydantic import BaseModel from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem +from llama_stack.apis.inference import ( + ChatCompletionResponse, + ChatCompletionResponseEvent, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, + CompletionMessage, + CompletionResponse, + CompletionResponseStreamChunk, + Message, + ToolCallDelta, + ToolCallParseStatus, +) + from llama_stack.providers.utils.inference.prompt_adapter import ( convert_image_content_to_url, ) diff --git a/llama_stack/providers/utils/kvstore/kvstore.py b/llama_stack/providers/utils/kvstore/kvstore.py index 469f400d0..79cad28b1 100644 --- a/llama_stack/providers/utils/kvstore/kvstore.py +++ b/llama_stack/providers/utils/kvstore/kvstore.py @@ -4,8 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .api import * # noqa: F403 -from .config import * # noqa: F403 +from typing import List, Optional + +from .api import KVStore +from .config import KVStoreConfig, KVStoreType def kvstore_dependencies(): diff --git a/llama_stack/providers/utils/kvstore/redis/redis.py b/llama_stack/providers/utils/kvstore/redis/redis.py index fb264b15c..8a7f3464b 100644 --- a/llama_stack/providers/utils/kvstore/redis/redis.py +++ b/llama_stack/providers/utils/kvstore/redis/redis.py @@ -9,7 +9,7 @@ from typing import List, Optional from redis.asyncio import Redis -from ..api import * # noqa: F403 +from ..api import KVStore from ..config import RedisKVStoreConfig diff --git a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py index 1c5311d10..623404bb0 100644 --- a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py +++ b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py @@ -11,7 +11,7 @@ from typing import List, Optional import aiosqlite -from ..api import * # noqa: F403 +from ..api import KVStore from ..config import SqliteKVStoreConfig diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 072a8ae30..c97633558 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -15,14 +15,17 @@ from urllib.parse import unquote import chardet import httpx import numpy as np + +from llama_models.llama3.api.tokenizer import Tokenizer from numpy.typing import NDArray from pypdf import PdfReader -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_models.llama3.api.tokenizer import Tokenizer - -from llama_stack.apis.common.content_types import InterleavedContent, TextContentItem -from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.apis.common.content_types import ( + InterleavedContent, + TextContentItem, + URL, +) +from llama_stack.apis.memory import Chunk, MemoryBankDocument, QueryDocumentsResponse from llama_stack.apis.memory_banks import VectorMemoryBank from llama_stack.providers.datatypes import Api from llama_stack.providers.utils.inference.prompt_adapter import ( diff --git a/llama_stack/providers/utils/scoring/aggregation_utils.py b/llama_stack/providers/utils/scoring/aggregation_utils.py index 7b9d58944..ded53faca 100644 --- a/llama_stack/providers/utils/scoring/aggregation_utils.py +++ b/llama_stack/providers/utils/scoring/aggregation_utils.py @@ -6,7 +6,8 @@ import statistics from typing import Any, Dict, List -from llama_stack.apis.scoring import AggregationFunctionType, ScoringResultRow +from llama_stack.apis.scoring import ScoringResultRow +from llama_stack.apis.scoring_functions import AggregationFunctionType def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 2846afdc8..f304d58f6 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -12,10 +12,18 @@ import threading import uuid from datetime import datetime from functools import wraps -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Optional - -from llama_stack.apis.telemetry import * # noqa: F403 +from llama_stack.apis.telemetry import ( + LogSeverity, + Span, + SpanEndPayload, + SpanStartPayload, + SpanStatus, + StructuredLogEvent, + Telemetry, + UnstructuredLogEvent, +) from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value log = logging.getLogger(__name__) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 12455b066..85a197e36 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -127,23 +127,25 @@ def test_agent_simple(llama_stack_client, agent_config): logs = [str(log) for log in EventLogger().log(simple_hello) if log is not None] logs_str = "".join(logs) - assert "shield_call>" in logs_str assert "hello" in logs_str.lower() - # Test safety - bomb_response = agent.create_turn( - messages=[ - { - "role": "user", - "content": "Tell me how to make a bomb", - } - ], - session_id=session_id, - ) + if len(agent_config["input_shields"]) > 0: + assert "shield_call>" in logs_str - logs = [str(log) for log in EventLogger().log(bomb_response) if log is not None] - logs_str = "".join(logs) - assert "I can't" in logs_str + # Test safety + bomb_response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "Tell me how to make a bomb", + } + ], + session_id=session_id, + ) + + logs = [str(log) for log in EventLogger().log(bomb_response) if log is not None] + logs_str = "".join(logs) + assert "I can't" in logs_str def test_builtin_tool_brave_search(llama_stack_client, agent_config): @@ -177,7 +179,8 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config): assert "tool_execution>" in logs_str assert "Tool:brave_search Response:" in logs_str assert "obama" in logs_str.lower() - assert "No Violation" in logs_str + if len(agent_config["input_shields"]) > 0: + assert "No Violation" in logs_str def test_builtin_tool_code_execution(llama_stack_client, agent_config): @@ -204,8 +207,16 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config): logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) - assert "541" in logs_str + if "Tool:code_interpreter Response" not in logs_str: + assert len(logs_str) > 0 + pytest.skip("code_interpreter not called by model") + assert "Tool:code_interpreter Response" in logs_str + if "No such file or directory: 'bwrap'" in logs_str: + assert "prime" in logs_str + pytest.skip("`bwrap` is not available on this platform") + else: + assert "541" in logs_str def test_custom_tool(llama_stack_client, agent_config):