From 27da763af9ac0455930f3c18a77d51f787851188 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 26 Dec 2024 18:30:42 -0800 Subject: [PATCH] more fixes --- .../inline/agents/meta_reference/agents.py | 17 ++++++++++++++--- .../inline/agents/meta_reference/persistence.py | 4 +++- .../meta_reference/rag/context_retriever.py | 4 +--- .../inline/agents/meta_reference/safety.py | 4 +++- .../agents/meta_reference/tools/safety.py | 2 +- .../inline/datasetio/localfs/config.py | 2 +- .../inline/datasetio/localfs/datasetio.py | 13 +++++++------ .../inline/eval/meta_reference/eval.py | 13 +++++++++---- .../inline/inference/meta_reference/config.py | 5 ++--- .../inference/meta_reference/generation.py | 6 +++--- 10 files changed, 44 insertions(+), 26 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index dec5ec960..93bfab5f4 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -9,15 +9,26 @@ import logging import shutil import tempfile import uuid -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional, Union from termcolor import colored -from llama_stack.apis.inference import Inference +from llama_stack.apis.agents import ( + AgentConfig, + AgentCreateResponse, + Agents, + AgentSessionCreateResponse, + AgentStepResponse, + AgentTurnCreateRequest, + Attachment, + Session, + Turn, +) + +from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage from llama_stack.apis.memory import Memory from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.safety import Safety -from llama_stack.apis.agents import * # noqa: F403 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 1c99e3d75..a4b1af616 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -10,9 +10,11 @@ import uuid from datetime import datetime from typing import List, Optional -from llama_stack.apis.agents import * # noqa: F403 + from pydantic import BaseModel +from llama_stack.apis.agents import Turn + from llama_stack.providers.utils.kvstore import KVStore log = logging.getLogger(__name__) diff --git a/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py b/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py index 7b5c8b4b0..74eb91c53 100644 --- a/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py +++ b/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py @@ -7,8 +7,6 @@ from typing import List from jinja2 import Template -from llama_models.llama3.api import * # noqa: F403 - from llama_stack.apis.agents import ( DefaultMemoryQueryGeneratorConfig, @@ -16,7 +14,7 @@ from llama_stack.apis.agents import ( MemoryQueryGenerator, MemoryQueryGeneratorConfig, ) -from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.inference import Message, UserMessage from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) diff --git a/llama_stack/providers/inline/agents/meta_reference/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py index 8fca4d310..90d193f90 100644 --- a/llama_stack/providers/inline/agents/meta_reference/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/safety.py @@ -9,7 +9,9 @@ import logging from typing import List -from llama_stack.apis.safety import * # noqa: F403 +from llama_stack.apis.inference import Message + +from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel log = logging.getLogger(__name__) diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/safety.py b/llama_stack/providers/inline/agents/meta_reference/tools/safety.py index 1ffc99edd..a34649756 100644 --- a/llama_stack/providers/inline/agents/meta_reference/tools/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/tools/safety.py @@ -7,7 +7,7 @@ from typing import List from llama_stack.apis.inference import Message -from llama_stack.apis.safety import * # noqa: F403 +from llama_stack.apis.safety import Safety from ..safety import ShieldRunnerMixin from .builtin import BaseTool diff --git a/llama_stack/providers/inline/datasetio/localfs/config.py b/llama_stack/providers/inline/datasetio/localfs/config.py index 58d563c99..1b89df63b 100644 --- a/llama_stack/providers/inline/datasetio/localfs/config.py +++ b/llama_stack/providers/inline/datasetio/localfs/config.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.apis.datasetio import * # noqa: F401, F403 +from pydantic import BaseModel class LocalFSDatasetIOConfig(BaseModel): ... diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index 736e5d8b9..442053fb3 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -3,18 +3,19 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Optional - -import pandas -from llama_models.llama3.api.datatypes import * # noqa: F403 - -from llama_stack.apis.datasetio import * # noqa: F403 import base64 import os from abc import ABC, abstractmethod from dataclasses import dataclass +from typing import Any, Dict, List, Optional from urllib.parse import urlparse +import pandas + +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult +from llama_stack.apis.datasets import Dataset + from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index e1c2cc804..00630132e 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -5,13 +5,15 @@ # the root directory of this source tree. from enum import Enum from typing import Any, Dict, List, Optional -from llama_models.llama3.api.datatypes import * # noqa: F403 + from tqdm import tqdm -from .....apis.common.job_types import Job -from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus -from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.agents import Agents +from llama_stack.apis.common.type_system import ( + ChatCompletionInputType, + CompletionInputType, + StringType, +) from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.eval_tasks import EvalTask @@ -20,6 +22,9 @@ from llama_stack.apis.scoring import Scoring from llama_stack.providers.datatypes import EvalTasksProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl +from .....apis.common.job_types import Job +from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus + from .config import MetaReferenceEvalConfig EVAL_TASKS_PREFIX = "eval_tasks:" diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py index 33af33fcd..2c46ef596 100644 --- a/llama_stack/providers/inline/inference/meta_reference/config.py +++ b/llama_stack/providers/inline/inference/meta_reference/config.py @@ -6,11 +6,10 @@ from typing import Any, Dict, Optional -from llama_models.datatypes import * # noqa: F403 - -from llama_stack.apis.inference import * # noqa: F401, F403 from pydantic import BaseModel, field_validator +from llama_stack.apis.inference import QuantizationConfig + from llama_stack.providers.utils.inference import supported_inference_models diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py index c89183cb7..9067fb043 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generation.py +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py @@ -32,11 +32,11 @@ from llama_models.llama3.reference_impl.multimodal.model import ( CrossAttentionTransformer, ) from llama_models.sku_list import resolve_model -from pydantic import BaseModel - -from llama_stack.apis.inference import * # noqa: F403 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData +from pydantic import BaseModel + +from llama_stack.apis.inference import ResponseFormat, ResponseFormatType from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.providers.utils.inference.prompt_adapter import (