[remove import *] clean up import *'s (#689)

# What does this PR do?

- as the title says, clean up `import *`'s
- upgrade tests to make them more robust to bad model outputs
- remove `import *`'s in `llama_stack/apis/*` (skipping `__init__` modules); see the before/after sketch below the list
<img width="465" alt="image"
src="https://github.com/user-attachments/assets/d8339c13-3b40-4ba5-9c53-0d2329726ee2"
/>

- ran `sh run_openapi_generator.sh`; no types are affected
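
For illustration, every change in this PR follows the same pattern: replace a wildcard import with the explicit names the module actually uses. A minimal sketch (the imported names appear in the diffs below; `build_prompt` is a hypothetical helper, not code from this PR):

```python
# Before: wildcard imports pull in unknown names and require flake8 suppressions.
#   from llama_stack.apis.inference import *  # noqa: F403

# After: import exactly what is used, so linters can flag unused or undefined names.
from llama_stack.apis.inference import Message, UserMessage


def build_prompt(text: str) -> Message:
    # Call sites do not change; only the import lines become explicit.
    return UserMessage(content=text)
```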

## Test Plan

### Providers Tests

**agents**
```
pytest -v -s llama_stack/providers/tests/agents/test_agents.py -m "together" --safety-shield meta-llama/Llama-Guard-3-8B --inference-model meta-llama/Llama-3.1-405B-Instruct-FP8
```

**inference**
```bash
# meta-reference
torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py
torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py

# together
pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py
pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py

pytest ./llama_stack/providers/tests/inference/test_prompt_adapter.py 
```

**safety**
```
pytest -v -s llama_stack/providers/tests/safety/test_safety.py -m together --safety-shield meta-llama/Llama-Guard-3-8B
```

**memory**
```
pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m "sentence_transformers" --env EMBEDDING_DIMENSION=384
```

**scoring**
```
pytest -v -s -m llm_as_judge_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py --judge-model meta-llama/Llama-3.2-3B-Instruct
pytest -v -s -m basic_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py
pytest -v -s -m braintrust_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py
```


**datasetio**
```
pytest -v -s -m localfs llama_stack/providers/tests/datasetio/test_datasetio.py
pytest -v -s -m huggingface llama_stack/providers/tests/datasetio/test_datasetio.py
```


**eval**
```
pytest -v -s -m meta_reference_eval_together_inference llama_stack/providers/tests/eval/test_eval.py
pytest -v -s -m meta_reference_eval_together_inference_huggingface_datasetio llama_stack/providers/tests/eval/test_eval.py
```

### Client-SDK Tests
```
LLAMA_STACK_BASE_URL=http://localhost:5000 pytest -v ./tests/client-sdk
```

### llama-stack-apps
```
PORT=5000
LOCALHOST=localhost

python -m examples.agents.hello $LOCALHOST $PORT
python -m examples.agents.inflation $LOCALHOST $PORT
python -m examples.agents.podcast_transcript $LOCALHOST $PORT
python -m examples.agents.rag_as_attachments $LOCALHOST $PORT
python -m examples.agents.rag_with_memory_bank $LOCALHOST $PORT
python -m examples.safety.llama_guard_demo_mm $LOCALHOST $PORT
python -m examples.agents.e2e_loop_with_custom_tools $LOCALHOST $PORT

# Vision model
python -m examples.interior_design_assistant.app
python -m examples.agent_store.app $LOCALHOST $PORT
```

### CLI
```
which llama
llama model prompt-format -m Llama3.2-11B-Vision-Instruct
llama model list
llama stack list-apis
llama stack list-providers inference

llama stack build --template ollama --image-type conda
```

### Distributions Tests
**ollama**
```
llama stack build --template ollama --image-type conda
ollama run llama3.2:1b-instruct-fp16
llama stack run ./llama_stack/templates/ollama/run.yaml --env INFERENCE_MODEL=meta-llama/Llama-3.2-1B-Instruct
```

**fireworks**
```
llama stack build --template fireworks --image-type conda
llama stack run ./llama_stack/templates/fireworks/run.yaml
```

**together**
```
llama stack build --template together --image-type conda
llama stack run ./llama_stack/templates/together/run.yaml
```

**tgi**
```
llama stack run ./llama_stack/templates/tgi/run.yaml --env TGI_URL=http://0.0.0.0:5009 --env INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
```

## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.

Commit 3c72c034e6 (parent 70db039ff4) by Xi Yan, 2024-12-27 15:45:44 -08:00, committed by GitHub. 99 changed files with 907 additions and 359 deletions.

```diff
@@ -67,7 +67,7 @@
 "from termcolor import cprint\n",
 "\n",
 "from llama_stack.distribution.datatypes import RemoteProviderConfig\n",
-"from llama_stack.apis.safety import * # noqa: F403\n",
+"from llama_stack.apis.safety import Safety\n",
 "from llama_stack_client import LlamaStackClient\n",
 "\n",
 "\n",
@@ -127,7 +127,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.15"
+"version": "3.11.10"
 }
 },
 "nbformat": 4,
```

```diff
@@ -18,18 +18,30 @@ from typing import (
     Union,
 )
 
+from llama_models.llama3.api.datatypes import ToolParamDefinition
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated
 
-from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
-from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_stack.apis.common.deployment_types import * # noqa: F403
-from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.apis.safety import * # noqa: F403
-from llama_stack.apis.memory import * # noqa: F403
 from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
+from llama_stack.apis.inference import (
+    CompletionMessage,
+    SamplingParams,
+    ToolCall,
+    ToolCallDelta,
+    ToolChoice,
+    ToolPromptFormat,
+    ToolResponse,
+    ToolResponseMessage,
+    UserMessage,
+)
+from llama_stack.apis.memory import MemoryBank
+from llama_stack.apis.safety import SafetyViolation
+from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 
 @json_schema_type
```

```diff
@@ -6,13 +6,14 @@
 from typing import Optional
 
-from llama_models.llama3.api.datatypes import * # noqa: F403
+from llama_models.llama3.api.datatypes import ToolPromptFormat
 from llama_models.llama3.api.tool_utils import ToolUtils
 from termcolor import cprint
 
 from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
+from llama_stack.apis.inference import ToolResponseMessage
 
 class LogEvent:
     def __init__(
```

```diff
@@ -10,8 +10,16 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 
-from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_stack.apis.inference import * # noqa: F403
+from llama_stack.apis.inference import (
+    CompletionMessage,
+    InterleavedContent,
+    LogProbConfig,
+    Message,
+    SamplingParams,
+    ToolChoice,
+    ToolDefinition,
+    ToolPromptFormat,
+)
 
 @json_schema_type
```

```diff
@@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel
 
-from llama_stack.apis.datasets import * # noqa: F403
+from llama_stack.apis.datasets import Dataset
 
 @json_schema_type
```

```diff
@@ -4,18 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Literal, Optional, Protocol, Union
+from typing import Any, Dict, List, Literal, Optional, Protocol, Union
 
+from llama_models.llama3.api.datatypes import BaseModel, Field
+from llama_models.schema_utils import json_schema_type, webmethod
 from typing_extensions import Annotated
-from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_models.schema_utils import json_schema_type, webmethod
-from llama_stack.apis.scoring_functions import * # noqa: F403
 
 from llama_stack.apis.agents import AgentConfig
 from llama_stack.apis.common.job_types import Job, JobStatus
-from llama_stack.apis.scoring import * # noqa: F403
-from llama_stack.apis.eval_tasks import * # noqa: F403
 from llama_stack.apis.inference import SamplingParams, SystemMessage
+from llama_stack.apis.scoring import ScoringResult
+from llama_stack.apis.scoring_functions import ScoringFnParams
 
 @json_schema_type
```

```diff
@@ -7,7 +7,9 @@
 from enum import Enum
 from typing import (
+    Any,
     AsyncIterator,
+    Dict,
     List,
     Literal,
     Optional,
@@ -32,8 +34,9 @@ from typing_extensions import Annotated
 from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.models import Model
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
-from llama_stack.apis.models import * # noqa: F403
 
 class LogProbConfig(BaseModel):
```

```diff
@@ -7,17 +7,17 @@
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional, Protocol, Union
+from typing import Any, Dict, List, Literal, Optional, Protocol, Union
 
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
-from llama_models.llama3.api.datatypes import * # noqa: F403
+from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.job_types import JobStatus
-from llama_stack.apis.datasets import * # noqa: F403
-from llama_stack.apis.common.training_types import * # noqa: F403
+from llama_stack.apis.common.training_types import Checkpoint
 
 @json_schema_type
```

```diff
@@ -4,13 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, List, Protocol, runtime_checkable
+from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
 
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel
 
-from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_stack.apis.scoring_functions import * # noqa: F403
+from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
 
 # mapping of metric to value
```

```diff
@@ -6,13 +6,12 @@
 from enum import Enum
-from typing import Any, Dict, List, Optional, Protocol
+from typing import Any, Dict, List, Optional, Protocol, Union
 
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel
 
-from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_stack.apis.inference import Message
```

```diff
@@ -6,11 +6,12 @@
 from typing import Any, Dict, Optional
 
-from pydantic import BaseModel, ConfigDict, Field
-from llama_models.datatypes import * # noqa: F403
+from llama_models.datatypes import CheckpointQuantizationFormat
+from llama_models.llama3.api.datatypes import SamplingParams
 from llama_models.sku_list import LlamaDownloadInfo
+from pydantic import BaseModel, ConfigDict, Field
 
 class PromptGuardModel(BaseModel):
     """Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""
```

```diff
@@ -3,21 +3,28 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import argparse
-
-from llama_stack.cli.subcommand import Subcommand
-from llama_stack.distribution.datatypes import * # noqa: F403
 import os
 import shutil
 from functools import lru_cache
 from pathlib import Path
+from typing import List, Optional
 
 import pkg_resources
 
+from llama_stack.cli.subcommand import Subcommand
+from llama_stack.distribution.datatypes import (
+    BuildConfig,
+    DistributionSpec,
+    Provider,
+    StackRunConfig,
+)
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.resolver import InvalidProviderError
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
+from llama_stack.providers.datatypes import Api
 
 TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
```

```diff
@@ -6,21 +6,22 @@
 import logging
 from enum import Enum
-from typing import List
+from pathlib import Path
+from typing import Dict, List
 
 import pkg_resources
 from pydantic import BaseModel
 from termcolor import cprint
 
-from llama_stack.distribution.utils.exec import run_with_pty
-from llama_stack.distribution.datatypes import * # noqa: F403
-from pathlib import Path
-
+from llama_stack.distribution.datatypes import BuildConfig, Provider
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
+from llama_stack.distribution.utils.exec import run_with_pty
+from llama_stack.providers.datatypes import Api
 
 log = logging.getLogger(__name__)
```

```diff
@@ -6,10 +6,14 @@
 import logging
 import textwrap
-from typing import Any
+from typing import Any, Dict
 
-from llama_stack.distribution.datatypes import * # noqa: F403
+from llama_stack.distribution.datatypes import (
+    DistributionSpec,
+    LLAMA_STACK_RUN_CONFIG_VERSION,
+    Provider,
+    StackRunConfig,
+)
 from llama_stack.distribution.distribution import (
     builtin_automatically_routed_apis,
     get_provider_registry,
@@ -17,10 +21,7 @@ from llama_stack.distribution.distribution import (
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
-
-from llama_stack.apis.models import * # noqa: F403
-from llama_stack.apis.shields import * # noqa: F403
-from llama_stack.apis.memory_banks import * # noqa: F403
+from llama_stack.providers.datatypes import Api, ProviderSpec
 
 logger = logging.getLogger(__name__)
```

```diff
@@ -4,24 +4,24 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Dict, List, Optional, Union
+from typing import Annotated, Any, Dict, List, Optional, Union
 
 from pydantic import BaseModel, Field
 
 from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import * # noqa: F403
+from llama_stack.apis.datasets import Dataset, DatasetInput
 from llama_stack.apis.eval import Eval
-from llama_stack.apis.eval_tasks import EvalTaskInput
+from llama_stack.apis.eval_tasks import EvalTask, EvalTaskInput
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import * # noqa: F403
-from llama_stack.apis.models import * # noqa: F403
+from llama_stack.apis.memory_banks import MemoryBank, MemoryBankInput
+from llama_stack.apis.models import Model, ModelInput
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.scoring import Scoring
-from llama_stack.apis.scoring_functions import * # noqa: F403
-from llama_stack.apis.shields import * # noqa: F403
+from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
+from llama_stack.apis.shields import Shield, ShieldInput
 from llama_stack.apis.tools import Tool, ToolGroup, ToolRuntime
-from llama_stack.providers.datatypes import * # noqa: F403
+from llama_stack.providers.datatypes import Api, ProviderSpec
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig
 
 LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
```

```diff
@@ -5,12 +5,12 @@
 # the root directory of this source tree.
 from typing import Dict, List
 
-from llama_stack.apis.inspect import * # noqa: F403
 from pydantic import BaseModel
 
+from llama_stack.apis.inspect import HealthInfo, Inspect, ProviderInfo, RouteInfo
+from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.server.endpoints import get_all_api_endpoints
-from llama_stack.providers.datatypes import * # noqa: F403
-from llama_stack.distribution.datatypes import * # noqa: F403
 
 class DistributionInspectConfig(BaseModel):
```

```diff
@@ -6,14 +6,10 @@
 import importlib
 import inspect
-from typing import Any, Dict, List, Set
-
-from llama_stack.providers.datatypes import * # noqa: F403
-from llama_stack.distribution.datatypes import * # noqa: F403
-
 import logging
+from typing import Any, Dict, List, Set
 
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
@@ -32,10 +28,32 @@ from llama_stack.apis.shields import Shields
 from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.distribution.client import get_client_impl
+from llama_stack.distribution.datatypes import (
+    AutoRoutedProviderSpec,
+    Provider,
+    RoutingTableProviderSpec,
+    StackRunConfig,
+)
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
+from llama_stack.providers.datatypes import (
+    Api,
+    DatasetsProtocolPrivate,
+    EvalTasksProtocolPrivate,
+    InlineProviderSpec,
+    MemoryBanksProtocolPrivate,
+    ModelsProtocolPrivate,
+    ProviderSpec,
+    RemoteProviderConfig,
+    RemoteProviderSpec,
+    ScoringFunctionsProtocolPrivate,
+    ShieldsProtocolPrivate,
+    ToolsProtocolPrivate,
+)
 
 log = logging.getLogger(__name__)
```

```diff
@@ -4,10 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any
+from typing import Any, Dict
 
-from llama_stack.distribution.datatypes import * # noqa: F403
+from llama_stack.distribution.datatypes import RoutedProtocol
 from llama_stack.distribution.store import DistributionRegistry
+from llama_stack.providers.datatypes import Api, RoutingTable
 
 from .routing_tables import (
     DatasetsRoutingTable,
```

```diff
@@ -6,16 +6,40 @@
 from typing import Any, AsyncGenerator, Dict, List, Optional
 
-from llama_stack.apis.datasetio import * # noqa: F403
-from llama_stack.apis.datasetio.datasetio import DatasetIO
-from llama_stack.apis.eval import * # noqa: F403
-from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.apis.memory import * # noqa: F403
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
+from llama_stack.apis.eval import (
+    AppEvalTaskConfig,
+    Eval,
+    EvalTaskConfig,
+    EvaluateResponse,
+    Job,
+    JobStatus,
+)
+from llama_stack.apis.inference import (
+    EmbeddingsResponse,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    SamplingParams,
+    ToolChoice,
+    ToolDefinition,
+    ToolPromptFormat,
+)
+from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
 from llama_stack.apis.memory_banks.memory_banks import BankParams
-from llama_stack.apis.safety import * # noqa: F403
-from llama_stack.apis.scoring import * # noqa: F403
-from llama_stack.apis.tools import * # noqa: F403
-from llama_stack.distribution.datatypes import RoutingTable
+from llama_stack.apis.models import ModelType
+from llama_stack.apis.safety import RunShieldResponse, Safety
+from llama_stack.apis.scoring import (
+    ScoreBatchResponse,
+    ScoreResponse,
+    Scoring,
+    ScoringFnParams,
+)
+from llama_stack.apis.shields import Shield
+from llama_stack.apis.tools import Tool, ToolGroupDef, ToolRuntime
+from llama_stack.providers.datatypes import RoutingTable
 
 class MemoryRouter(Memory):
@@ -330,7 +354,6 @@ class EvalRouter(Eval):
             task_config=task_config,
         )
 
-    @webmethod(route="/eval/evaluate_rows", method="POST")
     async def evaluate_rows(
         self,
         task_id: str,
```

```diff
@@ -6,19 +6,42 @@
 from typing import Any, Dict, List, Optional
 
-from llama_models.llama3.api.datatypes import * # noqa: F403
-
 from pydantic import parse_obj_as
 
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.type_system import ParamType
-from llama_stack.apis.datasets import * # noqa: F403
-from llama_stack.apis.eval_tasks import * # noqa: F403
-from llama_stack.apis.memory_banks import * # noqa: F403
-from llama_stack.apis.models import * # noqa: F403
-from llama_stack.apis.shields import * # noqa: F403
-from llama_stack.apis.tools import * # noqa: F403
-from llama_stack.distribution.datatypes import * # noqa: F403
+from llama_stack.apis.datasets import Dataset, Datasets
+from llama_stack.apis.eval_tasks import EvalTask, EvalTasks
+from llama_stack.apis.memory_banks import (
+    BankParams,
+    MemoryBank,
+    MemoryBanks,
+    MemoryBankType,
+)
+from llama_stack.apis.models import Model, Models, ModelType
+from llama_stack.apis.resource import ResourceType
+from llama_stack.apis.scoring_functions import (
+    ScoringFn,
+    ScoringFnParams,
+    ScoringFunctions,
+)
+from llama_stack.apis.shields import Shield, Shields
+from llama_stack.apis.tools import (
+    MCPToolGroupDef,
+    Tool,
+    ToolGroup,
+    ToolGroupDef,
+    ToolGroups,
+    UserDefinedToolGroupDef,
+)
+from llama_stack.distribution.datatypes import (
+    RoutableObject,
+    RoutableObjectWithProvider,
+    RoutedProtocol,
+)
 from llama_stack.distribution.store import DistributionRegistry
+from llama_stack.providers.datatypes import Api, RoutingTable
 
 def get_impl_api(p: Any) -> Api:
```

```diff
@@ -28,14 +28,9 @@ from pydantic import BaseModel, ValidationError
 from termcolor import cprint
 from typing_extensions import Annotated
 
+from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
-from llama_stack.providers.utils.telemetry.tracing import (
-    end_trace,
-    setup_logger,
-    start_trace,
-)
-from llama_stack.distribution.datatypes import * # noqa: F403
 from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import InvalidProviderError
 from llama_stack.distribution.stack import (
@@ -43,11 +38,19 @@ from llama_stack.distribution.stack import (
     replace_env_vars,
     validate_env_pair,
 )
+from llama_stack.providers.datatypes import Api
 from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
 from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
     TelemetryAdapter,
 )
+from llama_stack.providers.utils.telemetry.tracing import (
+    end_trace,
+    setup_logger,
+    start_trace,
+)
 
 from .endpoints import get_all_api_endpoints
```

```diff
@@ -8,32 +8,31 @@ import logging
 import os
 import re
 from pathlib import Path
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import pkg_resources
 import yaml
 from termcolor import colored
 
-from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_stack.apis.agents import * # noqa: F403
-from llama_stack.apis.datasets import * # noqa: F403
-from llama_stack.apis.datasetio import * # noqa: F403
-from llama_stack.apis.scoring import * # noqa: F403
-from llama_stack.apis.scoring_functions import * # noqa: F403
-from llama_stack.apis.eval import * # noqa: F403
-from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.apis.batch_inference import * # noqa: F403
-from llama_stack.apis.memory import * # noqa: F403
-from llama_stack.apis.telemetry import * # noqa: F403
-from llama_stack.apis.post_training import * # noqa: F403
-from llama_stack.apis.synthetic_data_generation import * # noqa: F403
-from llama_stack.apis.safety import * # noqa: F403
-from llama_stack.apis.models import * # noqa: F403
-from llama_stack.apis.memory_banks import * # noqa: F403
-from llama_stack.apis.shields import * # noqa: F403
-from llama_stack.apis.inspect import * # noqa: F403
-from llama_stack.apis.eval_tasks import * # noqa: F403
+from llama_stack.apis.agents import Agents
+from llama_stack.apis.batch_inference import BatchInference
+from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.eval import Eval
+from llama_stack.apis.eval_tasks import EvalTasks
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.inspect import Inspect
+from llama_stack.apis.memory import Memory
+from llama_stack.apis.memory_banks import MemoryBanks
+from llama_stack.apis.models import Models
+from llama_stack.apis.post_training import PostTraining
+from llama_stack.apis.safety import Safety
+from llama_stack.apis.scoring import Scoring
+from llama_stack.apis.scoring_functions import ScoringFunctions
+from llama_stack.apis.shields import Shields
+from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
+from llama_stack.apis.telemetry import Telemetry
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import get_provider_registry
```

```diff
@@ -13,11 +13,8 @@ import pydantic
 from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
 from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
-from llama_stack.providers.utils.kvstore import (
-    KVStore,
-    kvstore_impl,
-    SqliteKVStoreConfig,
-)
+from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 
 class DistributionRegistry(Protocol):
```

```diff
@@ -8,11 +8,14 @@ import os
 import pytest
 import pytest_asyncio
 
-from llama_stack.distribution.store import * # noqa F403
 from llama_stack.apis.inference import Model
 from llama_stack.apis.memory_banks import VectorMemoryBank
+from llama_stack.distribution.store.registry import (
+    CachedDiskDistributionRegistry,
+    DiskDistributionRegistry,
+)
 from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
-from llama_stack.distribution.datatypes import * # noqa F403
 
 @pytest.fixture
```

```diff
@@ -13,19 +13,64 @@ import secrets
 import string
 import uuid
 from datetime import datetime
-from typing import AsyncGenerator, List, Tuple
+from typing import AsyncGenerator, Dict, List, Optional, Tuple
 from urllib.parse import urlparse
 
 import httpx
+from llama_models.llama3.api.datatypes import BuiltinTool
 
-from llama_stack.apis.agents import * # noqa: F403
-from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.apis.memory import * # noqa: F403
-from llama_stack.apis.memory_banks import * # noqa: F403
-from llama_stack.apis.safety import * # noqa: F403
+from llama_stack.apis.agents import (
+    AgentConfig,
+    AgentTool,
+    AgentTurnCreateRequest,
+    AgentTurnResponseEvent,
+    AgentTurnResponseEventType,
+    AgentTurnResponseStepCompletePayload,
+    AgentTurnResponseStepProgressPayload,
+    AgentTurnResponseStepStartPayload,
+    AgentTurnResponseStreamChunk,
+    AgentTurnResponseTurnCompletePayload,
+    AgentTurnResponseTurnStartPayload,
+    Attachment,
+    CodeInterpreterToolDefinition,
+    FunctionCallToolDefinition,
+    InferenceStep,
+    MemoryRetrievalStep,
+    MemoryToolDefinition,
+    PhotogenToolDefinition,
+    SearchToolDefinition,
+    ShieldCallStep,
+    StepType,
+    ToolExecutionStep,
+    Turn,
+    WolframAlphaToolDefinition,
+)
-from llama_stack.apis.common.content_types import InterleavedContent, TextContentItem
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    TextContentItem,
+    URL,
+)
+from llama_stack.apis.inference import (
+    ChatCompletionResponseEventType,
+    CompletionMessage,
+    Inference,
+    Message,
+    SamplingParams,
+    StopReason,
+    SystemMessage,
+    ToolCallDelta,
+    ToolCallParseStatus,
+    ToolChoice,
+    ToolDefinition,
+    ToolResponse,
+    ToolResponseMessage,
+    UserMessage,
+)
+from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
+from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
+from llama_stack.apis.safety import Safety
 from llama_stack.providers.utils.kvstore import KVStore
 from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
```

```diff
@@ -9,15 +9,26 @@ import logging
 import shutil
 import tempfile
 import uuid
-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional, Union
 
 from termcolor import colored
 
-from llama_stack.apis.inference import Inference
+from llama_stack.apis.agents import (
+    AgentConfig,
+    AgentCreateResponse,
+    Agents,
+    AgentSessionCreateResponse,
+    AgentStepResponse,
+    AgentTurnCreateRequest,
+    Attachment,
+    Session,
+    Turn,
+)
+from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
 from llama_stack.apis.memory import Memory
 from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.agents import * # noqa: F403
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
```

```diff
@@ -10,9 +10,11 @@ import uuid
 from datetime import datetime
 from typing import List, Optional
 
-from llama_stack.apis.agents import * # noqa: F403
 from pydantic import BaseModel
 
+from llama_stack.apis.agents import Turn
 from llama_stack.providers.utils.kvstore import KVStore
 
 log = logging.getLogger(__name__)
```

```diff
@@ -7,8 +7,6 @@
 from typing import List
 
 from jinja2 import Template
-from llama_models.llama3.api import * # noqa: F403
-
 
 from llama_stack.apis.agents import (
     DefaultMemoryQueryGeneratorConfig,
@@ -16,7 +14,7 @@ from llama_stack.apis.agents import (
     MemoryQueryGenerator,
     MemoryQueryGeneratorConfig,
 )
-from llama_stack.apis.inference import * # noqa: F403
+from llama_stack.apis.inference import Message, UserMessage
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
```

```diff
@@ -9,7 +9,9 @@ import logging
 from typing import List
 
-from llama_stack.apis.safety import * # noqa: F403
+from llama_stack.apis.inference import Message
+from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
 
 log = logging.getLogger(__name__)
```

```diff
@@ -8,10 +8,26 @@ from typing import AsyncIterator, List, Optional, Union
 import pytest
 
-from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.apis.memory import * # noqa: F403
-from llama_stack.apis.safety import * # noqa: F403
-from llama_stack.apis.agents import * # noqa: F403
+from llama_stack.apis.agents import (
+    AgentConfig,
+    AgentTurnCreateRequest,
+    AgentTurnResponseTurnCompletePayload,
+)
+from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseStreamChunk,
+    CompletionMessage,
+    Message,
+    ResponseFormat,
+    SamplingParams,
+    ToolChoice,
+    ToolDefinition,
+    UserMessage,
+)
+from llama_stack.apis.memory import MemoryBank
+from llama_stack.apis.safety import RunShieldResponse
 
 from ..agents import (
     AGENT_INSTANCES_BY_ID,
```

```diff
@@ -7,7 +7,7 @@
 from typing import List
 
 from llama_stack.apis.inference import Message
-from llama_stack.apis.safety import * # noqa: F403
+from llama_stack.apis.safety import Safety
 
 from ..safety import ShieldRunnerMixin
 from .builtin import BaseTool
```

```diff
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from llama_stack.apis.datasetio import * # noqa: F401, F403
+from pydantic import BaseModel
 
 class LocalFSDatasetIOConfig(BaseModel): ...
```

```diff
@@ -3,18 +3,19 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, List, Optional
-
-import pandas
-from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_stack.apis.datasetio import * # noqa: F403
-
 import base64
 import os
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
 from urllib.parse import urlparse
 
+import pandas
+
+from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
+from llama_stack.apis.datasets import Dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
```

```diff
@@ -5,13 +5,15 @@
 # the root directory of this source tree.
 from enum import Enum
 from typing import Any, Dict, List, Optional
-from llama_models.llama3.api.datatypes import * # noqa: F403
 
 from tqdm import tqdm
 
-from .....apis.common.job_types import Job
-from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
-from llama_stack.apis.common.type_system import * # noqa: F403
 from llama_stack.apis.agents import Agents
+from llama_stack.apis.common.type_system import (
+    ChatCompletionInputType,
+    CompletionInputType,
+    StringType,
+)
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval_tasks import EvalTask
@@ -20,6 +22,9 @@ from llama_stack.apis.scoring import Scoring
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 
+from .....apis.common.job_types import Job
+from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
+
 from .config import MetaReferenceEvalConfig
 
 EVAL_TASKS_PREFIX = "eval_tasks:"
```

```diff
@@ -6,11 +6,10 @@
 from typing import Any, Dict, Optional
 
-from llama_models.datatypes import * # noqa: F403
-from llama_stack.apis.inference import * # noqa: F401, F403
 from pydantic import BaseModel, field_validator
 
+from llama_stack.apis.inference import QuantizationConfig
 from llama_stack.providers.utils.inference import supported_inference_models
```

```diff
@@ -32,11 +32,16 @@ from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
 from llama_models.sku_list import resolve_model
-from pydantic import BaseModel
-
-from llama_stack.apis.inference import * # noqa: F403
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
+from pydantic import BaseModel
 
+from llama_stack.apis.inference import (
+    Fp8QuantizationConfig,
+    Int4QuantizationConfig,
+    ResponseFormat,
+    ResponseFormatType,
+)
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -44,12 +49,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     CompletionRequestWithRawContent,
 )
 
-from .config import (
-    Fp8QuantizationConfig,
-    Int4QuantizationConfig,
-    MetaReferenceInferenceConfig,
-    MetaReferenceQuantizedInferenceConfig,
-)
+from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
 
 log = logging.getLogger(__name__)
```

```diff
@@ -7,10 +7,10 @@
 import logging
 import os
 import uuid
-from typing import AsyncGenerator, Optional
+from typing import AsyncGenerator, List, Optional
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
@@ -18,9 +18,26 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams as VLLMSamplingParams
 
-from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseStreamChunk,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
+    EmbeddingsResponse,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    SamplingParams,
+    ToolChoice,
+    ToolDefinition,
+    ToolPromptFormat,
+)
+from llama_stack.apis.models import Model
+from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
```

```diff
@@ -16,11 +16,14 @@ import faiss
 import numpy as np
 from numpy.typing import NDArray
 
-from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_stack.apis.memory import * # noqa: F403
 from llama_stack.apis.inference import InterleavedContent
-from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank
+from llama_stack.apis.memory import (
+    Chunk,
+    Memory,
+    MemoryBankDocument,
+    QueryDocumentsResponse,
+)
+from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank
 from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.vector_store import (
```

```diff
@@ -14,11 +14,10 @@ from enum import Enum
 from typing import Any, Callable, Dict, List
 
 import torch
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.common.type_system import * # noqa
 from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
-from llama_stack.apis.common.type_system import ParamType
+from llama_stack.apis.common.type_system import ParamType, StringType
+from llama_stack.apis.datasets import Datasets
 from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
```

```diff
@@ -3,11 +3,26 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+
+from llama_models.schema_utils import webmethod
 
 from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.post_training import (
+    AlgorithmConfig,
+    DPOAlignmentConfig,
+    JobStatus,
+    LoraFinetuningConfig,
+    PostTrainingJob,
+    PostTrainingJobArtifactsResponse,
+    PostTrainingJobStatusResponse,
+    TrainingConfig,
+)
 from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
-from llama_stack.apis.post_training import * # noqa
 from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
     LoraFinetuningSingleDevice,
 )
```

```diff
@@ -14,27 +14,33 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
 from llama_models.sku_list import resolve_model
 
+from llama_stack.apis.common.training_types import PostTrainingMetric
 from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.post_training import (
+    AlgorithmConfig,
+    Checkpoint,
+    LoraFinetuningConfig,
+    OptimizerConfig,
+    TrainingConfig,
+)
 from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
-from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
-    TorchtuneCheckpointer,
-)
-from torch import nn
-from torchtune import utils as torchtune_utils
-from torchtune.training.metric_logging import DiskLogger
-from tqdm import tqdm
-
-from llama_stack.apis.post_training import * # noqa
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.inline.post_training.torchtune.common import utils
+from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
+    TorchtuneCheckpointer,
+)
 from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
+from torch import nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
-from torchtune import modules, training
+from torchtune import modules, training, utils as torchtune_utils
 from torchtune.data import AlpacaToMessages, padded_collate_sft
 from torchtune.modules.loss import CEWithChunkedOutputLoss
@@ -47,6 +53,8 @@ from torchtune.modules.peft import (
     validate_missing_and_unexpected_for_lora,
 )
 from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
+from torchtune.training.metric_logging import DiskLogger
+from tqdm import tqdm
 
 log = logging.getLogger(__name__)
```

```diff
@@ -7,8 +7,14 @@
 import logging
 from typing import Any, Dict, List
 
-from llama_stack.apis.safety import * # noqa: F403
 from llama_stack.apis.inference import Message
+from llama_stack.apis.safety import (
+    RunShieldResponse,
+    Safety,
+    SafetyViolation,
+    ViolationLevel,
+)
+from llama_stack.apis.shields import Shield
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
```

```diff
@@ -9,10 +9,24 @@ import re
 from string import Template
 from typing import Any, Dict, List, Optional
 
-from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.apis.safety import * # noqa: F403
+from llama_models.datatypes import CoreModelId
+from llama_models.llama3.api.datatypes import Role
+
 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
+from llama_stack.apis.inference import (
+    ChatCompletionResponseEventType,
+    Inference,
+    Message,
+    UserMessage,
+)
+from llama_stack.apis.safety import (
+    RunShieldResponse,
+    Safety,
+    SafetyViolation,
+    ViolationLevel,
+)
+from llama_stack.apis.shields import Shield
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
```

```diff
@@ -11,11 +11,16 @@ import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
-from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.apis.safety import * # noqa: F403
-from llama_models.llama3.api.datatypes import * # noqa: F403
+from llama_stack.apis.inference import Message
+from llama_stack.apis.safety import (
+    RunShieldResponse,
+    Safety,
+    SafetyViolation,
+    ViolationLevel,
+)
+from llama_stack.apis.shields import Shield
+from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
```

```diff
@@ -3,14 +3,17 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import List
+from typing import Any, Dict, List, Optional
 
-from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_stack.apis.scoring import * # noqa: F403
-from llama_stack.apis.scoring_functions import * # noqa: F403
-from llama_stack.apis.common.type_system import * # noqa: F403
-from llama_stack.apis.datasetio import * # noqa: F403
-from llama_stack.apis.datasets import * # noqa: F403
+from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.scoring import (
+    ScoreBatchResponse,
+    ScoreResponse,
+    Scoring,
+    ScoringResult,
+)
+from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 
 from .config import BasicScoringConfig
```

```diff
@@ -3,20 +3,23 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import List
-
-from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_stack.apis.scoring import * # noqa: F403
-from llama_stack.apis.scoring_functions import * # noqa: F403
-from llama_stack.apis.common.type_system import * # noqa: F403
-from llama_stack.apis.datasetio import * # noqa: F403
-from llama_stack.apis.datasets import * # noqa: F403
 import os
+from typing import Any, Dict, List, Optional
 
 from autoevals.llm import Factuality
 from autoevals.ragas import AnswerCorrectness
+
+from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.scoring import (
+    ScoreBatchResponse,
+    ScoreResponse,
+    Scoring,
+    ScoringResult,
+    ScoringResultRow,
+)
+from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFn
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
```

```diff
@@ -3,7 +3,9 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from llama_stack.apis.scoring import * # noqa: F401, F403
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Field
 
 class BraintrustScoringConfig(BaseModel):
```

```diff
@@ -17,6 +17,22 @@ from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import BatchSpanProcessor
 from opentelemetry.semconv.resource import ResourceAttributes
 
+from llama_stack.apis.telemetry import (
+    Event,
+    MetricEvent,
+    QueryCondition,
+    SpanEndPayload,
+    SpanStartPayload,
+    SpanStatus,
+    SpanWithStatus,
+    StructuredLogEvent,
+    Telemetry,
+    Trace,
+    UnstructuredLogEvent,
+)
+from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
     ConsoleSpanProcessor,
 )
@@ -27,10 +43,6 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor
 from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
 from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
 
-from llama_stack.apis.telemetry import * # noqa: F403
-from llama_stack.distribution.datatypes import Api
-
 from .config import TelemetryConfig, TelemetrySink
 
 _GLOBAL_STORAGE = {
```

```diff
@@ -4,12 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_stack.apis.telemetry import Telemetry
+
 from .config import SampleConfig
 
-from llama_stack.apis.telemetry import * # noqa: F403
-
 
 class SampleTelemetryImpl(Telemetry):
     def __init__(self, config: SampleConfig):
         self.config = config
```

```diff
@@ -6,7 +6,13 @@
 from typing import List
 
-from llama_stack.distribution.datatypes import * # noqa: F403
+from llama_stack.providers.datatypes import (
+    AdapterSpec,
+    Api,
+    InlineProviderSpec,
+    ProviderSpec,
+    remote_provider_spec,
+)
 from llama_stack.providers.utils.kvstore import kvstore_dependencies
```

```diff
@@ -6,7 +6,13 @@
 from typing import List
 
-from llama_stack.distribution.datatypes import * # noqa: F403
+from llama_stack.providers.datatypes import (
+    AdapterSpec,
+    Api,
+    InlineProviderSpec,
+    ProviderSpec,
+    remote_provider_spec,
+)
 
 
 def available_providers() -> List[ProviderSpec]:
```

```diff
@@ -6,7 +6,7 @@
 from typing import List
 
-from llama_stack.distribution.datatypes import * # noqa: F403
+from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
 
 
 def available_providers() -> List[ProviderSpec]:
```

```diff
@@ -6,8 +6,13 @@
 from typing import List
 
-from llama_stack.distribution.datatypes import * # noqa: F403
+from llama_stack.providers.datatypes import (
+    AdapterSpec,
+    Api,
+    InlineProviderSpec,
+    ProviderSpec,
+    remote_provider_spec,
+)
 
 META_REFERENCE_DEPS = [
     "accelerate",
```

```diff
@@ -6,8 +6,13 @@
 from typing import List
 
-from llama_stack.distribution.datatypes import * # noqa: F403
+from llama_stack.providers.datatypes import (
+    AdapterSpec,
+    Api,
+    InlineProviderSpec,
+    ProviderSpec,
+    remote_provider_spec,
+)
 
 EMBEDDING_DEPS = [
     "blobfile",
```

```diff
@@ -6,7 +6,7 @@
 from typing import List
 
-from llama_stack.distribution.datatypes import * # noqa: F403
+from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
 
 
 def available_providers() -> List[ProviderSpec]:
```

```diff
@@ -6,7 +6,7 @@
 from typing import List
 
-from llama_stack.distribution.datatypes import (
+from llama_stack.providers.datatypes import (
     AdapterSpec,
     Api,
     InlineProviderSpec,
```

```diff
@@ -6,7 +6,7 @@
 from typing import List
 
-from llama_stack.distribution.datatypes import * # noqa: F403
+from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
 
 
 def available_providers() -> List[ProviderSpec]:
```

```diff
@@ -6,7 +6,13 @@
 from typing import List
 
-from llama_stack.distribution.datatypes import * # noqa: F403
+from llama_stack.providers.datatypes import (
+    AdapterSpec,
+    Api,
+    InlineProviderSpec,
+    ProviderSpec,
+    remote_provider_spec,
+)
 
 
 def available_providers() -> List[ProviderSpec]:
```

```diff
@@ -6,7 +6,7 @@
 from typing import List
 
-from llama_stack.distribution.datatypes import (
+from llama_stack.providers.datatypes import (
     AdapterSpec,
     Api,
     InlineProviderSpec,
```

```diff
@@ -4,12 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_stack.apis.agents import Agents
+
 from .config import SampleConfig
 
-from llama_stack.apis.agents import * # noqa: F403
-
 
 class SampleAgentsImpl(Agents):
     def __init__(self, config: SampleConfig):
         self.config = config
```

```diff
@@ -5,11 +5,11 @@
 # the root directory of this source tree.
 from typing import Any, Dict, List, Optional
 
-from llama_stack.apis.datasetio import * # noqa: F403
-
 import datasets as hf_datasets
 
+from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
+from llama_stack.apis.datasets import Dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
 from llama_stack.providers.utils.kvstore import kvstore_impl
```

```diff
@@ -4,8 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import * # noqa: F403
 import json
+from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 from botocore.client import BaseClient
 from llama_models.datatypes import CoreModelId
@@ -13,6 +13,24 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseStreamChunk,
+    EmbeddingsResponse,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    SamplingParams,
+    ToolChoice,
+    ToolDefinition,
+    ToolPromptFormat,
+)
+from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
+from llama_stack.providers.utils.bedrock.client import create_bedrock_client
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
@@ -29,11 +47,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
 
-from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
-from llama_stack.providers.utils.bedrock.client import create_bedrock_client
-
 
 MODEL_ALIASES = [
     build_model_alias(
```

```diff
@@ -4,17 +4,31 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional, Union
 
 from cerebras.cloud.sdk import AsyncCerebras
+from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 
-from llama_stack.apis.inference import * # noqa: F403
-from llama_models.datatypes import CoreModelId
-
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    CompletionResponse,
+    EmbeddingsResponse,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    SamplingParams,
+    ToolChoice,
+    ToolDefinition,
+    ToolPromptFormat,
+)
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
```

```diff
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional
 
 from llama_models.datatypes import CoreModelId
@@ -14,7 +14,20 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
 
-from llama_stack.apis.inference import * # noqa: F403
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    EmbeddingsResponse,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    SamplingParams,
+    ToolChoice,
+    ToolDefinition,
+    ToolPromptFormat,
+)
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
```

```diff
@@ -11,7 +11,24 @@ from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 
-from llama_stack.apis.inference import * # noqa: F403
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    CompletionRequest,
+    CompletionResponse,
+    EmbeddingsResponse,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    ResponseFormatType,
+    SamplingParams,
+    ToolChoice,
+    ToolDefinition,
+    ToolPromptFormat,
+)
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
```

```diff
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 import logging
-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional, Union
 
 import httpx
 from llama_models.datatypes import CoreModelId
@@ -14,15 +14,33 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient
 
+from llama_stack.apis.common.content_types import (
+    ImageContentItem,
+    InterleavedContent,
+    TextContentItem,
+)
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    CompletionRequest,
+    EmbeddingsResponse,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    SamplingParams,
+    ToolChoice,
+    ToolDefinition,
+    ToolPromptFormat,
+)
+from llama_stack.apis.models import Model, ModelType
+from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     build_model_alias_with_just_provider_model_id,
     ModelRegistryHelper,
 )
-from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
-from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     OpenAICompatCompletionChoice,
```

```diff
@@ -4,12 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.models import Model
+
 from .config import SampleConfig
 
-from llama_stack.apis.inference import * # noqa: F403
-
 
 class SampleInferenceImpl(Inference):
     def __init__(self, config: SampleConfig):
         self.config = config
```

```diff
@@ -13,10 +13,25 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import all_registered_models
 
-from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.apis.models import * # noqa: F403
-from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    CompletionRequest,
+    EmbeddingsResponse,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    ResponseFormatType,
+    SamplingParams,
+    ToolChoice,
+    ToolDefinition,
+    ToolPromptFormat,
+)
+from llama_stack.apis.models import Model
+from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
```

```diff
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional, Union
 
 from llama_models.datatypes import CoreModelId
@@ -14,7 +14,22 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from together import Together
 
-from llama_stack.apis.inference import * # noqa: F403
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    CompletionRequest,
+    EmbeddingsResponse,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    ResponseFormatType,
+    SamplingParams,
+    ToolChoice,
+    ToolDefinition,
+    ToolPromptFormat,
+)
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
```

```diff
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 import logging
-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional, Union
 
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -13,7 +13,25 @@ from llama_models.sku_list import all_registered_models
 from openai import OpenAI
 
-from llama_stack.apis.inference import * # noqa: F403
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
+    EmbeddingsResponse,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    ResponseFormatType,
+    SamplingParams,
+    ToolChoice,
+    ToolDefinition,
+    ToolPromptFormat,
+)
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
```

```diff
@@ -12,8 +12,14 @@ from urllib.parse import urlparse
 import chromadb
 from numpy.typing import NDArray
 
-from llama_stack.apis.memory import * # noqa: F403
-from llama_stack.apis.memory_banks import MemoryBankType
+from llama_stack.apis.inference import InterleavedContent
+from llama_stack.apis.memory import (
+    Chunk,
+    Memory,
+    MemoryBankDocument,
+    QueryDocumentsResponse,
+)
+from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType
 from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
 from llama_stack.providers.utils.memory.vector_store import (
```

```diff
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 import logging
-from typing import List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import psycopg2
 from numpy.typing import NDArray
@@ -14,8 +14,14 @@ from psycopg2.extras import execute_values, Json
 from pydantic import BaseModel, parse_obj_as
 
-from llama_stack.apis.memory import * # noqa: F403
-from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank
+from llama_stack.apis.inference import InterleavedContent
+from llama_stack.apis.memory import (
+    Chunk,
+    Memory,
+    MemoryBankDocument,
+    QueryDocumentsResponse,
+)
+from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank
 from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
```

```diff
@@ -6,16 +6,21 @@
 import logging
 import uuid
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 from numpy.typing import NDArray
 from qdrant_client import AsyncQdrantClient, models
 from qdrant_client.models import PointStruct
 
-from llama_stack.apis.memory_banks import * # noqa: F403
+from llama_stack.apis.inference import InterleavedContent
+from llama_stack.apis.memory import (
+    Chunk,
+    Memory,
+    MemoryBankDocument,
+    QueryDocumentsResponse,
+)
+from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType
 from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
-from llama_stack.apis.memory import * # noqa: F403
 from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
```

```diff
@@ -4,12 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_stack.apis.memory import Memory
+from llama_stack.apis.memory_banks import MemoryBank
+
 from .config import SampleConfig
 
-from llama_stack.apis.memory import * # noqa: F403
-
 
 class SampleMemoryImpl(Memory):
     def __init__(self, config: SampleConfig):
         self.config = config
```

```diff
@@ -14,8 +14,14 @@ from numpy.typing import NDArray
 from weaviate.classes.init import Auth
 from weaviate.classes.query import Filter
 
-from llama_stack.apis.memory import * # noqa: F403
-from llama_stack.apis.memory_banks import MemoryBankType
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.memory import (
+    Chunk,
+    Memory,
+    MemoryBankDocument,
+    QueryDocumentsResponse,
+)
+from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
```

```diff
@@ -9,8 +9,15 @@ import logging
 from typing import Any, Dict, List
 
-from llama_stack.apis.safety import * # noqa
-from llama_models.llama3.api.datatypes import * # noqa: F403
+from llama_stack.apis.inference import Message
+from llama_stack.apis.safety import (
+    RunShieldResponse,
+    Safety,
+    SafetyViolation,
+    ViolationLevel,
+)
+from llama_stack.apis.shields import Shield
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
```

```diff
@@ -4,12 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_stack.apis.safety import Safety
+from llama_stack.apis.shields import Shield
+
 from .config import SampleConfig
 
-from llama_stack.apis.safety import * # noqa: F403
-
 
 class SampleSafetyImpl(Safety):
     def __init__(self, config: SampleConfig):
         self.config = config
```

```diff
@@ -5,11 +5,31 @@
 # the root directory of this source tree.
 
 import os
+from typing import Dict, List
 
 import pytest
+from llama_models.llama3.api.datatypes import BuiltinTool
 
-from llama_stack.apis.agents import * # noqa: F403
-from llama_stack.providers.datatypes import * # noqa: F403
+from llama_stack.apis.agents import (
+    AgentConfig,
+    AgentTool,
+    AgentTurnResponseEventType,
+    AgentTurnResponseStepCompletePayload,
+    AgentTurnResponseStreamChunk,
+    AgentTurnResponseTurnCompletePayload,
+    Attachment,
+    MemoryToolDefinition,
+    SearchEngineType,
+    SearchToolDefinition,
+    ShieldCallStep,
+    StepType,
+    ToolChoice,
+    ToolExecutionStep,
+    Turn,
+)
+from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage
+from llama_stack.apis.safety import ViolationLevel
+from llama_stack.providers.datatypes import Api
 
 # How to run this test:
 #
```

```diff
@@ -6,9 +6,9 @@
 
 import pytest
 
-from llama_stack.apis.agents import * # noqa: F403
-from llama_stack.providers.datatypes import * # noqa: F403
+from llama_stack.apis.agents import AgentConfig, Turn
+from llama_stack.apis.inference import SamplingParams, UserMessage
+from llama_stack.providers.datatypes import Api
 from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
 
 from .fixtures import pick_inference_model
```

```diff
@@ -4,16 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import os
-import pytest
-
-from llama_stack.apis.common.type_system import * # noqa: F403
-from llama_stack.apis.datasetio import * # noqa: F403
-from llama_stack.distribution.datatypes import * # noqa: F403
-
 import base64
 import mimetypes
+import os
 from pathlib import Path
 
+import pytest
+
+from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
+from llama_stack.apis.datasets import Datasets
+
 # How to run this test:
 #
 # pytest llama_stack/providers/tests/datasetio/test_datasetio.py
```

```diff
@@ -7,8 +7,7 @@
 
 import pytest
 
-from llama_models.llama3.api import SamplingParams, URL
-
+from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
 from llama_stack.apis.eval.eval import (
@@ -16,6 +15,7 @@ from llama_stack.apis.eval.eval import (
     BenchmarkEvalTaskConfig,
     ModelCandidate,
 )
+from llama_stack.apis.inference import SamplingParams
 from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
```

```diff
@@ -6,8 +6,14 @@
 
 import unittest
 
-from llama_models.llama3.api import * # noqa: F403
-from llama_stack.apis.inference.inference import * # noqa: F403
+from llama_models.llama3.api.datatypes import (
+    BuiltinTool,
+    ToolDefinition,
+    ToolParamDefinition,
+    ToolPromptFormat,
+)
+from llama_stack.apis.inference import ChatCompletionRequest, SystemMessage, UserMessage
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
 )
@@ -24,7 +30,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 UserMessage(content=content),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@@ -41,7 +47,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.brave_search),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@@ -69,7 +75,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
             ],
             tool_prompt_format=ToolPromptFormat.json,
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 3)
         self.assertTrue("Environment: ipython" in messages[0].content)
@@ -99,7 +105,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
        self.assertEqual(len(messages), 3)
         self.assertTrue("Environment: ipython" in messages[0].content)
@@ -121,7 +127,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.code_interpreter),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2, messages)
         self.assertTrue(messages[0].content.endswith(system_prompt))
```
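
One behavioral change runs through this test file: `chat_completion_request_to_messages` now takes the model identifier as a second argument, presumably so the adapter can choose the prompt template per model family. A minimal usage sketch; the `MODEL` constant is assumed to be a registered Llama model identifier defined elsewhere in the test module:

```python
from llama_stack.apis.inference import ChatCompletionRequest, UserMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
)

MODEL = "Llama3.1-8B-Instruct"  # assumed identifier, matching the test fixtures

request = ChatCompletionRequest(
    model=MODEL,
    messages=[UserMessage(content="What is the capital of France?")],
)
# The second argument is the new part of the signature.
messages = chat_completion_request_to_messages(request, MODEL)
```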

```diff
@@ -7,13 +7,32 @@
 
 import pytest
 
+from llama_models.llama3.api.datatypes import (
+    SamplingParams,
+    StopReason,
+    ToolCall,
+    ToolDefinition,
+    ToolParamDefinition,
+    ToolPromptFormat,
+)
 from pydantic import BaseModel, ValidationError
 
-from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_stack.apis.inference import * # noqa: F403
-
-from llama_stack.distribution.datatypes import * # noqa: F403
+from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
+    JsonSchemaResponseFormat,
+    LogProbConfig,
+    SystemMessage,
+    ToolCallDelta,
+    ToolCallParseStatus,
+    ToolChoice,
+    UserMessage,
+)
+from llama_stack.apis.models import Model
 
 from .utils import group_chunks
```

```diff
@@ -8,11 +8,16 @@ from pathlib import Path
 
 import pytest
 
-from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_stack.apis.inference import * # noqa: F403
+from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem, URL
+from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    SamplingParams,
+    UserMessage,
+)
 
 from .utils import group_chunks
 
 THIS_DIR = Path(__file__).parent
```

```diff
@@ -10,8 +10,7 @@ import tempfile
 import pytest
 import pytest_asyncio
 
-from llama_stack.apis.inference import ModelInput, ModelType
-
+from llama_stack.apis.models import ModelInput, ModelType
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
 from llama_stack.providers.inline.memory.faiss import FaissImplConfig
@@ -19,7 +18,7 @@ from llama_stack.providers.remote.memory.chroma import ChromaRemoteImplConfig
 from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
 from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
 from llama_stack.providers.tests.resolver import construct_stack_for_test
-from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 
 from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail
```

```diff
@@ -8,14 +8,18 @@ import uuid
 
 import pytest
 
-from llama_stack.apis.memory import * # noqa: F403
-from llama_stack.distribution.datatypes import * # noqa: F403
-from llama_stack.apis.memory_banks.memory_banks import VectorMemoryBankParams
+from llama_stack.apis.memory import MemoryBankDocument, QueryDocumentsResponse
+from llama_stack.apis.memory_banks import (
+    MemoryBank,
+    MemoryBanks,
+    VectorMemoryBankParams,
+)
 
 # How to run this test:
 #
 # pytest llama_stack/providers/tests/memory/test_memory.py
-#   -m "meta_reference"
+#   -m "sentence_transformers" --env EMBEDDING_DIMENSION=384
 #   -v -s --tb=short --disable-warnings
```

```diff
@@ -7,8 +7,9 @@
 import pytest
 import pytest_asyncio
 
-from llama_stack.apis.common.type_system import * # noqa: F403
+from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.common.type_system import StringType
 from llama_stack.apis.datasets import DatasetInput
 from llama_stack.apis.models import ModelInput
```

```diff
@@ -4,9 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import pytest
-from llama_stack.apis.common.type_system import * # noqa: F403
-from llama_stack.apis.post_training import * # noqa: F403
-from llama_stack.distribution.datatypes import * # noqa: F403
+
+from llama_stack.apis.common.type_system import JobStatus
+from llama_stack.apis.post_training import (
+    Checkpoint,
+    DataConfig,
+    LoraFinetuningConfig,
+    OptimizerConfig,
+    PostTrainingJob,
+    PostTrainingJobArtifactsResponse,
+    PostTrainingJobStatusResponse,
+    TrainingConfig,
+)
 
 # How to run this test:
 #
```

```diff
@@ -8,14 +8,24 @@ import json
 import tempfile
 from typing import Any, Dict, List, Optional
 
-from llama_stack.distribution.datatypes import * # noqa: F403
+from pydantic import BaseModel
+
+from llama_stack.apis.datasets import DatasetInput
+from llama_stack.apis.eval_tasks import EvalTaskInput
+from llama_stack.apis.memory_banks import MemoryBankInput
+from llama_stack.apis.models import ModelInput
+from llama_stack.apis.scoring_functions import ScoringFnInput
+from llama_stack.apis.shields import ShieldInput
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
+from llama_stack.distribution.datatypes import Provider, StackRunConfig
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import resolve_remote_stack_impls
 from llama_stack.distribution.stack import construct_stack
-from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
+from llama_stack.providers.datatypes import Api, RemoteProviderConfig
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 
 
 class TestStack(BaseModel):
```

```diff
@@ -6,11 +6,9 @@
 
 import pytest
 
-from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_stack.apis.safety import * # noqa: F403
-from llama_stack.distribution.datatypes import * # noqa: F403
 from llama_stack.apis.inference import UserMessage
+from llama_stack.apis.safety import ViolationLevel
+from llama_stack.apis.shields import Shield
 
 # How to run this test:
 #
```

```diff
@@ -197,7 +197,7 @@ class TestScoring:
                     judge_score_regexes=[r"Score: (\d+)"],
                     aggregation_functions=aggr_fns,
                 )
-            elif x.provider_id == "basic":
+            elif x.provider_id == "basic" or x.provider_id == "braintrust":
                 if "regex_parser" in x.identifier:
                     scoring_functions[x.identifier] = RegexParserScoringFnParams(
                         aggregation_functions=aggr_fns,
```

```diff
@@ -4,17 +4,28 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator, Optional
+from typing import AsyncGenerator, List, Optional
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import StopReason
-from llama_stack.apis.inference import * # noqa: F403
+from llama_models.llama3.api.datatypes import SamplingParams, StopReason
 from pydantic import BaseModel
 
 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
+from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    CompletionMessage,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
+    Message,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
 )
```

```diff
@@ -4,8 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .api import * # noqa: F403
-from .config import * # noqa: F403
+from typing import List, Optional
+
+from .api import KVStore
+from .config import KVStoreConfig, KVStoreType
 
 
 def kvstore_dependencies():
```
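
The kvstore package `__init__` gets the same treatment as regular modules: rather than re-exporting everything from its submodules, it names exactly what it exports. A sketch of the idea; the `__all__` line is an optional extra, not part of this diff:

```python
# Sketch of an explicit package __init__, not the full file.
from .api import KVStore
from .config import KVStoreConfig, KVStoreType

# Optional: pin the public surface so even wildcard importers of this
# package only see the intended names.
__all__ = ["KVStore", "KVStoreConfig", "KVStoreType"]
```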

```diff
@@ -9,7 +9,7 @@ from typing import List, Optional
 
 from redis.asyncio import Redis
 
-from ..api import * # noqa: F403
+from ..api import KVStore
 from ..config import RedisKVStoreConfig
```

```diff
@@ -11,7 +11,7 @@ from typing import List, Optional
 
 import aiosqlite
 
-from ..api import * # noqa: F403
+from ..api import KVStore
 from ..config import SqliteKVStoreConfig
```

```diff
@@ -15,14 +15,17 @@ from urllib.parse import unquote
 import chardet
 import httpx
 import numpy as np
+from llama_models.llama3.api.tokenizer import Tokenizer
 from numpy.typing import NDArray
 from pypdf import PdfReader
 
-from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_models.llama3.api.tokenizer import Tokenizer
-
-from llama_stack.apis.common.content_types import InterleavedContent, TextContentItem
-from llama_stack.apis.memory import * # noqa: F403
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    TextContentItem,
+    URL,
+)
+from llama_stack.apis.memory import Chunk, MemoryBankDocument, QueryDocumentsResponse
 from llama_stack.apis.memory_banks import VectorMemoryBank
 from llama_stack.providers.datatypes import Api
 from llama_stack.providers.utils.inference.prompt_adapter import (
```

```diff
@@ -6,7 +6,8 @@
 import statistics
 from typing import Any, Dict, List
 
-from llama_stack.apis.scoring import AggregationFunctionType, ScoringResultRow
+from llama_stack.apis.scoring import ScoringResultRow
+from llama_stack.apis.scoring_functions import AggregationFunctionType
 
 
 def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
```

```diff
@@ -12,10 +12,18 @@ import threading
 import uuid
 from datetime import datetime
 from functools import wraps
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Optional
 
-from llama_stack.apis.telemetry import * # noqa: F403
-
+from llama_stack.apis.telemetry import (
+    LogSeverity,
+    Span,
+    SpanEndPayload,
+    SpanStartPayload,
+    SpanStatus,
+    StructuredLogEvent,
+    Telemetry,
+    UnstructuredLogEvent,
+)
 from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
 
 log = logging.getLogger(__name__)
```

```diff
@@ -127,9 +127,11 @@ def test_agent_simple(llama_stack_client, agent_config):
     logs = [str(log) for log in EventLogger().log(simple_hello) if log is not None]
     logs_str = "".join(logs)
 
-    assert "shield_call>" in logs_str
     assert "hello" in logs_str.lower()
 
+    if len(agent_config["input_shields"]) > 0:
+        assert "shield_call>" in logs_str
+
     # Test safety
     bomb_response = agent.create_turn(
         messages=[
@@ -177,6 +179,7 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config):
     assert "tool_execution>" in logs_str
     assert "Tool:brave_search Response:" in logs_str
     assert "obama" in logs_str.lower()
-    assert "No Violation" in logs_str
+    if len(agent_config["input_shields"]) > 0:
+        assert "No Violation" in logs_str
@@ -204,8 +207,16 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
     logs = [str(log) for log in EventLogger().log(response) if log is not None]
     logs_str = "".join(logs)
 
-    assert "541" in logs_str
+    if "Tool:code_interpreter Response" not in logs_str:
+        assert len(logs_str) > 0
+        pytest.skip("code_interpreter not called by model")
 
     assert "Tool:code_interpreter Response" in logs_str
+    if "No such file or directory: 'bwrap'" in logs_str:
+        assert "prime" in logs_str
+        pytest.skip("`bwrap` is not available on this platform")
+    else:
+        assert "541" in logs_str
 
 
 def test_custom_tool(llama_stack_client, agent_config):
```
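
The client-sdk agent assertions are now conditioned on runtime state instead of assumed model behavior: shield output is only required when input shields are configured, and a turn where the model never invoked the tool is skipped as inconclusive rather than failed. The guard pattern in miniature (a sketch, not the exact test code):

```python
import pytest

def check_code_interpreter(logs_str: str) -> None:
    # A model that never called the tool makes the run inconclusive,
    # not wrong, so skip instead of failing outright.
    if "Tool:code_interpreter Response" not in logs_str:
        pytest.skip("code_interpreter not called by model")
    assert "Tool:code_interpreter Response" in logs_str
```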