remove imports 1/n

This commit is contained in:
Xi Yan 2024-12-26 15:19:06 -08:00
parent 28ce511986
commit da26d22f90
8 changed files with 43 additions and 22 deletions

View file

@ -67,7 +67,7 @@
"from termcolor import cprint\n",
"\n",
"from llama_stack.distribution.datatypes import RemoteProviderConfig\n",
"from llama_stack.apis.safety import * # noqa: F403\n",
"from llama_stack.apis.safety import Safety\n",
"from llama_stack_client import LlamaStackClient\n",
"\n",
"\n",
@ -127,7 +127,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
"version": "3.11.10"
}
},
"nbformat": 4,

View file

@ -18,18 +18,30 @@ from typing import (
Union,
)
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.common.deployment_types import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
from llama_stack.apis.inference import (
CompletionMessage,
SamplingParams,
ToolCall,
ToolCallDelta,
ToolChoice,
ToolPromptFormat,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.memory import MemoryBank
from llama_stack.apis.safety import SafetyViolation
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@json_schema_type

View file

@ -6,13 +6,14 @@
from typing import Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_models.llama3.api.tool_utils import ToolUtils
from termcolor import cprint
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
from llama_stack.apis.inference import ToolResponseMessage
class LogEvent:
def __init__(

View file

@ -10,8 +10,16 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import (
CompletionMessage,
InterleavedContent,
LogProbConfig,
Message,
SamplingParams,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
)
@json_schema_type

View file

@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.datasets import Dataset
@json_schema_type

View file

@ -4,18 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Optional, Protocol, Union
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from llama_models.llama3.api.datatypes import BaseModel, Field
from llama_models.schema_utils import json_schema_type, webmethod
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.schema_utils import json_schema_type, webmethod
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
@json_schema_type

View file

@ -7,8 +7,7 @@
import pytest
from llama_models.llama3.api import SamplingParams, URL
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
from llama_stack.apis.eval.eval import (
@ -16,6 +15,7 @@ from llama_stack.apis.eval.eval import (
BenchmarkEvalTaskConfig,
ModelCandidate,
)
from llama_stack.apis.inference import SamplingParams
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset

View file

@ -197,7 +197,7 @@ class TestScoring:
judge_score_regexes=[r"Score: (\d+)"],
aggregation_functions=aggr_fns,
)
elif x.provider_id == "basic":
elif x.provider_id == "basic" or x.provider_id == "braintrust":
if "regex_parser" in x.identifier:
scoring_functions[x.identifier] = RegexParserScoringFnParams(
aggregation_functions=aggr_fns,