diff --git a/docs/zero_to_hero_guide/06_Safety101.ipynb b/docs/zero_to_hero_guide/06_Safety101.ipynb index 6b5bd53bf..e2ba5e22e 100644 --- a/docs/zero_to_hero_guide/06_Safety101.ipynb +++ b/docs/zero_to_hero_guide/06_Safety101.ipynb @@ -67,7 +67,7 @@ "from termcolor import cprint\n", "\n", "from llama_stack.distribution.datatypes import RemoteProviderConfig\n", - "from llama_stack.apis.safety import * # noqa: F403\n", + "from llama_stack.apis.safety import Safety\n", "from llama_stack_client import LlamaStackClient\n", "\n", "\n", @@ -127,7 +127,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.11.10" } }, "nbformat": 4, diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 5fd90ae7a..5748b4e41 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -18,18 +18,30 @@ from typing import ( Union, ) +from llama_models.llama3.api.datatypes import ToolParamDefinition + from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated -from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.common.deployment_types import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 -from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.common.content_types import InterleavedContent, URL +from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig +from llama_stack.apis.inference import ( + CompletionMessage, + SamplingParams, + ToolCall, + ToolCallDelta, + ToolChoice, + ToolPromptFormat, + ToolResponse, + ToolResponseMessage, + UserMessage, +) +from llama_stack.apis.memory import MemoryBank +from llama_stack.apis.safety import SafetyViolation + +from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @json_schema_type diff --git a/llama_stack/apis/agents/event_logger.py b/llama_stack/apis/agents/event_logger.py index 4c379999e..40a69d19c 100644 --- a/llama_stack/apis/agents/event_logger.py +++ b/llama_stack/apis/agents/event_logger.py @@ -6,13 +6,14 @@ from typing import Optional -from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_models.llama3.api.datatypes import ToolPromptFormat from llama_models.llama3.api.tool_utils import ToolUtils - from termcolor import cprint from llama_stack.apis.agents import AgentTurnResponseEventType, StepType +from llama_stack.apis.inference import ToolResponseMessage + class LogEvent: def __init__( diff --git a/llama_stack/apis/batch_inference/batch_inference.py b/llama_stack/apis/batch_inference/batch_inference.py index 358cf3c35..f7b8b4387 100644 --- a/llama_stack/apis/batch_inference/batch_inference.py +++ b/llama_stack/apis/batch_inference/batch_inference.py @@ -10,8 +10,16 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.inference import ( + CompletionMessage, + InterleavedContent, + LogProbConfig, + Message, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) @json_schema_type diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py index 22acc3211..983e0e4ea 100644 --- a/llama_stack/apis/datasetio/datasetio.py +++ b/llama_stack/apis/datasetio/datasetio.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel -from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.datasets import Dataset @json_schema_type diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index 2e0ce1fbc..2592bca37 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -4,18 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Literal, Optional, Protocol, Union +from typing import Any, Dict, List, Literal, Optional, Protocol, Union + +from llama_models.llama3.api.datatypes import BaseModel, Field +from llama_models.schema_utils import json_schema_type, webmethod from typing_extensions import Annotated -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_models.schema_utils import json_schema_type, webmethod -from llama_stack.apis.scoring_functions import * # noqa: F403 from llama_stack.apis.agents import AgentConfig from llama_stack.apis.common.job_types import Job, JobStatus -from llama_stack.apis.scoring import * # noqa: F403 -from llama_stack.apis.eval_tasks import * # noqa: F403 from llama_stack.apis.inference import SamplingParams, SystemMessage +from llama_stack.apis.scoring import ScoringResult +from llama_stack.apis.scoring_functions import ScoringFnParams @json_schema_type diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py index 38da74128..d6794d488 100644 --- a/llama_stack/providers/tests/eval/test_eval.py +++ b/llama_stack/providers/tests/eval/test_eval.py @@ -7,8 +7,7 @@ import pytest -from llama_models.llama3.api import SamplingParams, URL - +from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType from llama_stack.apis.eval.eval import ( @@ -16,6 +15,7 @@ from llama_stack.apis.eval.eval import ( BenchmarkEvalTaskConfig, ModelCandidate, ) +from llama_stack.apis.inference import SamplingParams from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams from llama_stack.distribution.datatypes import Api from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py index dce069df0..2643b8fd6 100644 --- a/llama_stack/providers/tests/scoring/test_scoring.py +++ b/llama_stack/providers/tests/scoring/test_scoring.py @@ -197,7 +197,7 @@ class TestScoring: judge_score_regexes=[r"Score: (\d+)"], aggregation_functions=aggr_fns, ) - elif x.provider_id == "basic": + elif x.provider_id == "basic" or x.provider_id == "braintrust": if "regex_parser" in x.identifier: scoring_functions[x.identifier] = RegexParserScoringFnParams( aggregation_functions=aggr_fns,