From 4936794de1928539d18db40647c01db55895e5dc Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 17 Dec 2024 10:22:14 -0800 Subject: [PATCH] Move InterleavedContent to api/common/content_types.py --- llama_stack/apis/agents/agents.py | 3 +- llama_stack/apis/common/content_types.py | 49 +++++++++++++++++++ llama_stack/apis/common/deployment_types.py | 10 ---- llama_stack/apis/datasets/datasets.py | 2 +- llama_stack/apis/inference/inference.py | 32 +----------- llama_stack/apis/memory/memory.py | 2 +- .../distribution/routers/routing_tables.py | 2 +- .../inline/safety/llama_guard/llama_guard.py | 1 + .../remote/inference/ollama/ollama.py | 2 +- .../tests/inference/test_vision_inference.py | 1 + .../providers/tests/post_training/fixtures.py | 2 +- .../providers/utils/datasetio/url_utils.py | 2 +- .../utils/inference/openai_compat.py | 2 + .../utils/inference/prompt_adapter.py | 12 +++-- .../providers/utils/memory/file_utils.py | 2 +- .../providers/utils/memory/vector_store.py | 2 +- 16 files changed, 71 insertions(+), 55 deletions(-) create mode 100644 llama_stack/apis/common/content_types.py diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 51b93b621..5fd90ae7a 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -29,8 +29,7 @@ from llama_stack.apis.common.deployment_types import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.common.deployment_types import URL -from llama_stack.apis.inference import InterleavedContent +from llama_stack.apis.common.content_types import InterleavedContent, URL @json_schema_type diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py new file mode 100644 index 000000000..1403dd782 --- /dev/null +++ b/llama_stack/apis/common/content_types.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Annotated, List, Literal, Union + +from llama_models.schema_utils import json_schema_type, register_schema + +from pydantic import BaseModel, Field + + +@json_schema_type( + schema={"type": "string", "format": "uri", "pattern": "^(https?://|file://|data:)"} +) +class URL(BaseModel): + uri: str + + def __str__(self) -> str: + return self.uri + + +@json_schema_type +class ImageContentItem(BaseModel): + type: Literal["image"] = "image" + data: Union[bytes, URL] + + +@json_schema_type +class TextContentItem(BaseModel): + type: Literal["text"] = "text" + text: str + + +# other modalities can be added here +InterleavedContentItem = register_schema( + Annotated[ + Union[ImageContentItem, TextContentItem], + Field(discriminator="type"), + ], + name="InterleavedContentItem", +) + +# accept a single "str" as a special case since it is common +InterleavedContent = register_schema( + Union[str, InterleavedContentItem, List[InterleavedContentItem]], + name="InterleavedContent", +) diff --git a/llama_stack/apis/common/deployment_types.py b/llama_stack/apis/common/deployment_types.py index 35a53031c..67096ac52 100644 --- a/llama_stack/apis/common/deployment_types.py +++ b/llama_stack/apis/common/deployment_types.py @@ -12,16 +12,6 @@ from llama_models.schema_utils import json_schema_type from pydantic import BaseModel -@json_schema_type( - schema={"type": "string", "format": "uri", "pattern": "^(https?://|file://|data:)"} -) -class URL(BaseModel): - uri: str - - def __str__(self) -> str: - return self.uri - - @json_schema_type class RestAPIMethod(Enum): GET = "GET" diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 2dbf9bd42..7afc0f8fd 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -10,7 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field -from llama_stack.apis.common.deployment_types import URL +from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.resource import Resource, ResourceType diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index b465d8509..c481d04d7 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -25,12 +25,12 @@ from llama_models.llama3.api.datatypes import ( ToolPromptFormat, ) -from llama_models.schema_utils import json_schema_type, register_schema, webmethod +from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field, field_validator from typing_extensions import Annotated -from llama_stack.apis.common.deployment_types import URL +from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.apis.models import * # noqa: F403 @@ -69,34 +69,6 @@ QuantizationConfig = Annotated[ ] -@json_schema_type -class ImageContentItem(BaseModel): - type: Literal["image"] = "image" - data: Union[bytes, URL] - - -@json_schema_type -class TextContentItem(BaseModel): - type: Literal["text"] = "text" - text: str - - -# other modalities can be added here -InterleavedContentItem = register_schema( - Annotated[ - Union[ImageContentItem, TextContentItem], - Field(discriminator="type"), - ], - name="InterleavedContentItem", -) - -# accept a single "str" as a special case since it is common -InterleavedContent = register_schema( - Union[str, InterleavedContentItem, List[InterleavedContentItem]], - name="InterleavedContent", -) - - @json_schema_type class UserMessage(BaseModel): role: Literal["user"] = "user" diff --git a/llama_stack/apis/memory/memory.py b/llama_stack/apis/memory/memory.py index 85d637ca7..8096a107a 100644 --- a/llama_stack/apis/memory/memory.py +++ b/llama_stack/apis/memory/memory.py @@ -13,7 +13,7 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field -from llama_stack.apis.common.deployment_types import URL +from llama_stack.apis.common.content_types import URL from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.memory_banks import MemoryBank from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index d9b5e1319..ecf47a054 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -16,7 +16,7 @@ from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.eval_tasks import * # noqa: F403 -from llama_stack.apis.common.deployment_types import URL +from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ParamType from llama_stack.distribution.store import DistributionRegistry diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index ce749ec5b..c243427d3 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403 +from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem from llama_stack.distribution.datatypes import Api from llama_stack.providers.datatypes import ShieldsProtocolPrivate diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index f02f3682d..2f51f1299 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -21,8 +21,8 @@ from llama_stack.providers.utils.inference.model_registry import ( ) from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem from llama_stack.providers.datatypes import ModelsProtocolPrivate - from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, OpenAICompatCompletionChoice, diff --git a/llama_stack/providers/tests/inference/test_vision_inference.py b/llama_stack/providers/tests/inference/test_vision_inference.py index 967e124fe..d29ace491 100644 --- a/llama_stack/providers/tests/inference/test_vision_inference.py +++ b/llama_stack/providers/tests/inference/test_vision_inference.py @@ -11,6 +11,7 @@ import pytest from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem from .utils import group_chunks diff --git a/llama_stack/providers/tests/post_training/fixtures.py b/llama_stack/providers/tests/post_training/fixtures.py index eb7f3a66b..17d9668b2 100644 --- a/llama_stack/providers/tests/post_training/fixtures.py +++ b/llama_stack/providers/tests/post_training/fixtures.py @@ -8,7 +8,7 @@ import pytest import pytest_asyncio from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.common.deployment_types import URL +from llama_stack.apis.common.content_types import URL from llama_stack.apis.datasets import DatasetInput from llama_stack.apis.models import ModelInput diff --git a/llama_stack/providers/utils/datasetio/url_utils.py b/llama_stack/providers/utils/datasetio/url_utils.py index 4e99a3daf..da1e84d4d 100644 --- a/llama_stack/providers/utils/datasetio/url_utils.py +++ b/llama_stack/providers/utils/datasetio/url_utils.py @@ -10,7 +10,7 @@ from urllib.parse import unquote import pandas -from llama_stack.apis.common.deployment_types import URL +from llama_stack.apis.common.content_types import URL from llama_stack.providers.utils.memory.vector_store import parse_data_url diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 0f1e6894e..871e39aaa 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -13,6 +13,8 @@ from llama_models.llama3.api.datatypes import StopReason from llama_stack.apis.inference import * # noqa: F403 from pydantic import BaseModel +from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem + from llama_stack.providers.utils.inference.prompt_adapter import ( convert_image_content_to_url, ) diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 928b089e0..4f51467c2 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -34,19 +34,21 @@ from llama_models.llama3.prompt_templates import ( from llama_models.sku_list import resolve_model from PIL import Image as PIL_Image -from llama_stack.apis.common.deployment_types import URL +from llama_stack.apis.common.content_types import ( + ImageContentItem, + InterleavedContent, + InterleavedContentItem, + TextContentItem, + URL, +) from llama_stack.apis.inference import ( ChatCompletionRequest, CompletionRequest, - ImageContentItem, - InterleavedContent, - InterleavedContentItem, Message, ResponseFormat, ResponseFormatType, SystemMessage, - TextContentItem, ToolChoice, UserMessage, ) diff --git a/llama_stack/providers/utils/memory/file_utils.py b/llama_stack/providers/utils/memory/file_utils.py index 9ea3397fd..4c40056f3 100644 --- a/llama_stack/providers/utils/memory/file_utils.py +++ b/llama_stack/providers/utils/memory/file_utils.py @@ -8,7 +8,7 @@ import base64 import mimetypes import os -from llama_stack.apis.common.deployment_types import URL +from llama_stack.apis.common.content_types import URL def data_url_from_file(file_path: str) -> URL: diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index cfe5c2816..072a8ae30 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -21,7 +21,7 @@ from pypdf import PdfReader from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.tokenizer import Tokenizer -from llama_stack.apis.inference import InterleavedContent, TextContentItem +from llama_stack.apis.common.content_types import InterleavedContent, TextContentItem from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory_banks import VectorMemoryBank from llama_stack.providers.datatypes import Api