Move InterleavedContent to api/common/content_types.py

Ashwin Bharambe 2024-12-17 10:22:14 -08:00
parent a30aaaa2e5
commit 4936794de1
16 changed files with 71 additions and 55 deletions


@@ -29,8 +29,7 @@ from llama_stack.apis.common.deployment_types import * # noqa: F403
 from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.apis.safety import * # noqa: F403
 from llama_stack.apis.memory import * # noqa: F403
-from llama_stack.apis.common.deployment_types import URL
-from llama_stack.apis.inference import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, URL


 @json_schema_type


@@ -0,0 +1,49 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Annotated, List, Literal, Union
+
+from llama_models.schema_utils import json_schema_type, register_schema
+from pydantic import BaseModel, Field
+
+
+@json_schema_type(
+    schema={"type": "string", "format": "uri", "pattern": "^(https?://|file://|data:)"}
+)
+class URL(BaseModel):
+    uri: str
+
+    def __str__(self) -> str:
+        return self.uri
+
+
+@json_schema_type
+class ImageContentItem(BaseModel):
+    type: Literal["image"] = "image"
+    data: Union[bytes, URL]
+
+
+@json_schema_type
+class TextContentItem(BaseModel):
+    type: Literal["text"] = "text"
+    text: str
+
+
+# other modalities can be added here
+InterleavedContentItem = register_schema(
+    Annotated[
+        Union[ImageContentItem, TextContentItem],
+        Field(discriminator="type"),
+    ],
+    name="InterleavedContentItem",
+)
+
+# accept a single "str" as a special case since it is common
+InterleavedContent = register_schema(
+    Union[str, InterleavedContentItem, List[InterleavedContentItem]],
+    name="InterleavedContent",
+)
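
For reference, a minimal usage sketch of the relocated types. It assumes an environment where llama-stack with this change is importable; the prompt text and example URIs are purely illustrative.

from llama_stack.apis.common.content_types import (
    ImageContentItem,
    InterleavedContent,
    TextContentItem,
    URL,
)

# InterleavedContent accepts a plain string (the common special case noted above) ...
as_str: InterleavedContent = "What is in this image?"

# ... a single typed item ...
as_item: InterleavedContent = TextContentItem(text="What is in this image?")

# ... or a list of items; ImageContentItem.data takes raw bytes or a URL.
as_list: InterleavedContent = [
    TextContentItem(text="Describe the following picture:"),
    ImageContentItem(data=URL(uri="https://example.com/cat.png")),
]

# URL.__str__ returns the bare uri string.
print(str(URL(uri="file:///tmp/report.pdf")))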


@@ -12,16 +12,6 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel


-@json_schema_type(
-    schema={"type": "string", "format": "uri", "pattern": "^(https?://|file://|data:)"}
-)
-class URL(BaseModel):
-    uri: str
-
-    def __str__(self) -> str:
-        return self.uri
-
-
 @json_schema_type
 class RestAPIMethod(Enum):
     GET = "GET"


@@ -10,7 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

-from llama_stack.apis.common.deployment_types import URL
+from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.type_system import ParamType
 from llama_stack.apis.resource import Resource, ResourceType


@@ -25,12 +25,12 @@ from llama_models.llama3.api.datatypes import (
     ToolPromptFormat,
 )
-from llama_models.schema_utils import json_schema_type, register_schema, webmethod
+from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated

-from llama_stack.apis.common.deployment_types import URL
+from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 from llama_stack.apis.models import * # noqa: F403
@@ -69,34 +69,6 @@ QuantizationConfig = Annotated[
 ]

-@json_schema_type
-class ImageContentItem(BaseModel):
-    type: Literal["image"] = "image"
-    data: Union[bytes, URL]
-
-
-@json_schema_type
-class TextContentItem(BaseModel):
-    type: Literal["text"] = "text"
-    text: str
-
-
-# other modalities can be added here
-InterleavedContentItem = register_schema(
-    Annotated[
-        Union[ImageContentItem, TextContentItem],
-        Field(discriminator="type"),
-    ],
-    name="InterleavedContentItem",
-)
-
-# accept a single "str" as a special case since it is common
-InterleavedContent = register_schema(
-    Union[str, InterleavedContentItem, List[InterleavedContentItem]],
-    name="InterleavedContent",
-)
-

 @json_schema_type
 class UserMessage(BaseModel):
     role: Literal["user"] = "user"


@@ -13,7 +13,7 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

-from llama_stack.apis.common.deployment_types import URL
+from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.inference import InterleavedContent
 from llama_stack.apis.memory_banks import MemoryBank
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol


@@ -16,7 +16,7 @@ from llama_stack.apis.memory_banks import * # noqa: F403
 from llama_stack.apis.datasets import * # noqa: F403
 from llama_stack.apis.eval_tasks import * # noqa: F403

-from llama_stack.apis.common.deployment_types import URL
+from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.type_system import ParamType
 from llama_stack.distribution.store import DistributionRegistry


@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional
 from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.apis.safety import * # noqa: F403
+from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate


@@ -21,8 +21,8 @@ from llama_stack.providers.utils.inference.model_registry import (
 )
 from llama_stack.apis.inference import * # noqa: F403
+from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     OpenAICompatCompletionChoice,


@@ -11,6 +11,7 @@ import pytest
 from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_stack.apis.inference import * # noqa: F403
+from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
 from .utils import group_chunks


@@ -8,7 +8,7 @@ import pytest
 import pytest_asyncio

 from llama_stack.apis.common.type_system import * # noqa: F403
-from llama_stack.apis.common.deployment_types import URL
+from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.datasets import DatasetInput
 from llama_stack.apis.models import ModelInput


@@ -10,7 +10,7 @@ from urllib.parse import unquote
 import pandas

-from llama_stack.apis.common.deployment_types import URL
+from llama_stack.apis.common.content_types import URL
 from llama_stack.providers.utils.memory.vector_store import parse_data_url


@@ -13,6 +13,8 @@ from llama_models.llama3.api.datatypes import StopReason
 from llama_stack.apis.inference import * # noqa: F403
 from pydantic import BaseModel

+from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
 )


@@ -34,19 +34,21 @@ from llama_models.llama3.prompt_templates import (
 from llama_models.sku_list import resolve_model
 from PIL import Image as PIL_Image

-from llama_stack.apis.common.deployment_types import URL
+from llama_stack.apis.common.content_types import (
+    ImageContentItem,
+    InterleavedContent,
+    InterleavedContentItem,
+    TextContentItem,
+    URL,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     CompletionRequest,
-    ImageContentItem,
-    InterleavedContent,
-    InterleavedContentItem,
     Message,
     ResponseFormat,
     ResponseFormatType,
     SystemMessage,
-    TextContentItem,
     ToolChoice,
     UserMessage,
 )


@@ -8,7 +8,7 @@ import base64
 import mimetypes
 import os

-from llama_stack.apis.common.deployment_types import URL
+from llama_stack.apis.common.content_types import URL


 def data_url_from_file(file_path: str) -> URL:


@@ -21,7 +21,7 @@ from pypdf import PdfReader
 from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer

-from llama_stack.apis.inference import InterleavedContent, TextContentItem
+from llama_stack.apis.common.content_types import InterleavedContent, TextContentItem
 from llama_stack.apis.memory import * # noqa: F403
 from llama_stack.apis.memory_banks import VectorMemoryBank
 from llama_stack.providers.datatypes import Api