From e51154964f5d6c6b452a75616d5368d8cffc6323 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sun, 15 Dec 2024 17:31:21 -0800 Subject: [PATCH] Tests pass with Ollama now --- llama_stack/apis/agents/agents.py | 14 ++++++---- .../apis/batch_inference/batch_inference.py | 4 +-- llama_stack/apis/common/training_types.py | 2 ++ llama_stack/apis/datasets/datasets.py | 4 +-- llama_stack/apis/eval/eval.py | 1 + llama_stack/apis/inference/inference.py | 2 +- llama_stack/apis/memory/memory.py | 14 +++++----- llama_stack/apis/safety/safety.py | 10 +++---- .../synthetic_data_generation.py | 1 + llama_stack/distribution/routers/routers.py | 6 ++-- .../distribution/routers/routing_tables.py | 5 +--- .../inline/agents/meta_reference/safety.py | 2 -- .../safety/code_scanner/code_scanner.py | 10 +++++-- .../remote/inference/cerebras/cerebras.py | 1 - .../remote/inference/databricks/databricks.py | 1 - .../remote/inference/fireworks/fireworks.py | 1 - .../remote/inference/ollama/ollama.py | 3 +- .../providers/remote/inference/tgi/tgi.py | 4 +-- .../remote/inference/together/together.py | 1 - .../providers/remote/inference/vllm/vllm.py | 1 - .../tests/inference/test_vision_inference.py | 28 +++++++++++++------ .../providers/tests/post_training/fixtures.py | 2 +- .../providers/utils/datasetio/url_utils.py | 2 +- .../utils/inference/embedding_mixin.py | 10 ++++--- .../utils/inference/openai_compat.py | 8 ++++-- .../utils/inference/prompt_adapter.py | 9 +++--- .../providers/utils/memory/file_utils.py | 2 +- 27 files changed, 83 insertions(+), 65 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 575f336af..51b93b621 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -29,11 +29,13 @@ from llama_stack.apis.common.deployment_types import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.apis.common.deployment_types import URL +from llama_stack.apis.inference import InterleavedContent @json_schema_type class Attachment(BaseModel): - content: InterleavedTextMedia | URL + content: InterleavedContent | URL mime_type: str @@ -102,20 +104,20 @@ class _MemoryBankConfigCommon(BaseModel): class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value + type: Literal["vector"] = "vector" class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value + type: Literal["keyvalue"] = "keyvalue" keys: List[str] # what keys to focus on class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value + type: Literal["keyword"] = "keyword" class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value + type: Literal["graph"] = "graph" entities: List[str] # what entities to focus on @@ -230,7 +232,7 @@ class MemoryRetrievalStep(StepCommon): StepType.memory_retrieval.value ) memory_bank_ids: List[str] - inserted_context: InterleavedTextMedia + inserted_context: InterleavedContent Step = Annotated[ diff --git a/llama_stack/apis/batch_inference/batch_inference.py b/llama_stack/apis/batch_inference/batch_inference.py index 4e15b28a6..358cf3c35 100644 --- a/llama_stack/apis/batch_inference/batch_inference.py +++ 
b/llama_stack/apis/batch_inference/batch_inference.py @@ -17,7 +17,7 @@ from llama_stack.apis.inference import * # noqa: F403 @json_schema_type class BatchCompletionRequest(BaseModel): model: str - content_batch: List[InterleavedTextMedia] + content_batch: List[InterleavedContent] sampling_params: Optional[SamplingParams] = SamplingParams() logprobs: Optional[LogProbConfig] = None @@ -53,7 +53,7 @@ class BatchInference(Protocol): async def batch_completion( self, model: str, - content_batch: List[InterleavedTextMedia], + content_batch: List[InterleavedContent], sampling_params: Optional[SamplingParams] = SamplingParams(), logprobs: Optional[LogProbConfig] = None, ) -> BatchCompletionResponse: ... diff --git a/llama_stack/apis/common/training_types.py b/llama_stack/apis/common/training_types.py index b4bd1b0c6..ed278553e 100644 --- a/llama_stack/apis/common/training_types.py +++ b/llama_stack/apis/common/training_types.py @@ -10,6 +10,8 @@ from typing import Optional from llama_models.schema_utils import json_schema_type from pydantic import BaseModel +from llama_stack.apis.common.deployment_types import URL + @json_schema_type class PostTrainingMetric(BaseModel): diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index e1ac4af21..2dbf9bd42 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -6,12 +6,12 @@ from typing import Any, Dict, List, Literal, Optional, Protocol -from llama_models.llama3.api.datatypes import URL - from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field +from llama_stack.apis.common.deployment_types import URL + from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.resource import Resource, ResourceType diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index e52d4dab6..2e0ce1fbc 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -15,6 +15,7 @@ from llama_stack.apis.agents import AgentConfig from llama_stack.apis.common.job_types import Job, JobStatus from llama_stack.apis.scoring import * # noqa: F403 from llama_stack.apis.eval_tasks import * # noqa: F403 +from llama_stack.apis.inference import SamplingParams, SystemMessage @json_schema_type diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 1255f8b76..2d7936cf7 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -247,7 +247,7 @@ class CompletionResponseStreamChunk(BaseModel): @json_schema_type class BatchCompletionRequest(BaseModel): model: str - content_batch: List[InterleavedTextMedia] + content_batch: List[InterleavedContent] sampling_params: Optional[SamplingParams] = SamplingParams() response_format: Optional[ResponseFormat] = None logprobs: Optional[LogProbConfig] = None diff --git a/llama_stack/apis/memory/memory.py b/llama_stack/apis/memory/memory.py index 2f3a94956..85d637ca7 100644 --- a/llama_stack/apis/memory/memory.py +++ b/llama_stack/apis/memory/memory.py @@ -8,27 +8,27 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import List, Optional, Protocol, runtime_checkable +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod - from pydantic import BaseModel, Field -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.memory_banks import * # noqa: F403 +from llama_stack.apis.common.deployment_types import URL +from llama_stack.apis.inference import InterleavedContent +from llama_stack.apis.memory_banks import MemoryBank from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @json_schema_type class MemoryBankDocument(BaseModel): document_id: str - content: InterleavedTextMedia | URL + content: InterleavedContent | URL mime_type: str | None = None metadata: Dict[str, Any] = Field(default_factory=dict) class Chunk(BaseModel): - content: InterleavedTextMedia + content: InterleavedContent token_count: int document_id: str @@ -62,6 +62,6 @@ class Memory(Protocol): async def query_documents( self, bank_id: str, - query: InterleavedTextMedia, + query: InterleavedContent, params: Optional[Dict[str, Any]] = None, ) -> QueryDocumentsResponse: ... diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index 26ae45ae7..dd24642b1 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -5,16 +5,16 @@ # the root directory of this source tree. from enum import Enum -from typing import Any, Dict, List, Protocol, runtime_checkable +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod -from pydantic import BaseModel +from pydantic import BaseModel, Field + +from llama_stack.apis.inference import Message +from llama_stack.apis.shields import Shield from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.shields import * # noqa: F403 - @json_schema_type class ViolationLevel(Enum): diff --git a/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py index 717a0ec2f..4ffaa4d1e 100644 --- a/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py +++ b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py @@ -13,6 +13,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.inference import Message class FilteringFunction(Enum): diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 16ae35357..586ebfae4 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -59,7 +59,7 @@ class MemoryRouter(Memory): async def query_documents( self, bank_id: str, - query: InterleavedTextMedia, + query: InterleavedContent, params: Optional[Dict[str, Any]] = None, ) -> QueryDocumentsResponse: return await self.routing_table.get_provider_impl(bank_id).query_documents( @@ -133,7 +133,7 @@ class InferenceRouter(Inference): async def completion( self, model_id: str, - content: InterleavedTextMedia, + content: InterleavedContent, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, @@ -163,7 +163,7 @@ class 
InferenceRouter(Inference): async def embeddings( self, model_id: str, - contents: List[InterleavedTextMedia], + contents: List[InterleavedContent], ) -> EmbeddingsResponse: model = await self.routing_table.get_model(model_id) if model is None: diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 01edf4e5a..d9b5e1319 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -16,8 +16,7 @@ from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.eval_tasks import * # noqa: F403 - -from llama_models.llama3.api.datatypes import URL +from llama_stack.apis.common.deployment_types import URL from llama_stack.apis.common.type_system import ParamType from llama_stack.distribution.store import DistributionRegistry @@ -30,7 +29,6 @@ def get_impl_api(p: Any) -> Api: # TODO: this should return the registered object for all APIs async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject: - api = get_impl_api(p) assert obj.provider_id != "remote", "Remote provider should not be registered" @@ -76,7 +74,6 @@ class CommonRoutingTableImpl(RoutingTable): self.dist_registry = dist_registry async def initialize(self) -> None: - async def add_objects( objs: List[RoutableObjectWithProvider], provider_id: str, cls ) -> None: diff --git a/llama_stack/providers/inline/agents/meta_reference/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py index 3eca94fc5..8fca4d310 100644 --- a/llama_stack/providers/inline/agents/meta_reference/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/safety.py @@ -9,8 +9,6 @@ import logging from typing import List -from llama_models.llama3.api.datatypes import Message - from llama_stack.apis.safety import * # noqa: F403 log = logging.getLogger(__name__) diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py index 54a4d0b18..46b5e57da 100644 --- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py +++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py @@ -7,13 +7,17 @@ import logging from typing import Any, Dict, List -from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message +from llama_stack.apis.safety import * # noqa: F403 +from llama_stack.apis.inference import Message +from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, +) from .config import CodeScannerConfig -from llama_stack.apis.safety import * # noqa: F403 log = logging.getLogger(__name__) + ALLOWED_CODE_SCANNER_MODEL_IDS = [ "CodeScanner", "CodeShield", @@ -48,7 +52,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety): from codeshield.cs import CodeShield - text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages]) + text = "\n".join([interleaved_content_as_str(m.content) for m in messages]) log.info(f"Running CodeScannerShield on {text[50:]}") result = await CodeShield.scan_code(text) diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 65022f85e..235c4be7d 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -10,7 +10,6 @@ from cerebras.cloud.sdk import AsyncCerebras 
from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.apis.inference import * # noqa: F403 diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 0ebb625bc..4428560e0 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -10,7 +10,6 @@ from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer from openai import OpenAI diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index ef63abdb0..bb3ee67ec 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -10,7 +10,6 @@ from fireworks.client import Fireworks from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.apis.inference import * # noqa: F403 from llama_stack.distribution.request_headers import NeedsRequestProviderData diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 573747536..f02f3682d 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -11,7 +11,6 @@ import httpx from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer from ollama import AsyncClient @@ -90,7 +89,7 @@ model_aliases = [ CoreModelId.llama3_2_11b_vision_instruct.value, ), build_model_alias_with_just_provider_model_id( - "llama3.2-vision", + "llama3.2-vision:latest", CoreModelId.llama3_2_11b_vision_instruct.value, ), build_model_alias( diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 01981c62b..f82bb2c77 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -83,7 +83,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): async def completion( self, model_id: str, - content: InterleavedTextMedia, + content: InterleavedContent, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, @@ -267,7 +267,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): async def embeddings( self, model_id: str, - contents: List[InterleavedTextMedia], + contents: List[InterleavedContent], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 2b45a9954..b2e6e06ba 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -10,7 +10,6 @@ from llama_models.datatypes import CoreModelId from 
llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer from together import Together diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 7c8e000f9..12392ea50 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -8,7 +8,6 @@ import logging from typing import AsyncGenerator from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import all_registered_models diff --git a/llama_stack/providers/tests/inference/test_vision_inference.py b/llama_stack/providers/tests/inference/test_vision_inference.py index 56fa4c075..967e124fe 100644 --- a/llama_stack/providers/tests/inference/test_vision_inference.py +++ b/llama_stack/providers/tests/inference/test_vision_inference.py @@ -7,7 +7,6 @@ from pathlib import Path import pytest -from PIL import Image as PIL_Image from llama_models.llama3.api.datatypes import * # noqa: F403 @@ -17,6 +16,9 @@ from .utils import group_chunks THIS_DIR = Path(__file__).parent +with open(THIS_DIR / "pasta.jpeg", "rb") as f: + PASTA_IMAGE = f.read() + class TestVisionModelInference: @pytest.mark.asyncio @@ -24,12 +26,12 @@ class TestVisionModelInference: "image, expected_strings", [ ( - ImageMedia(image=PIL_Image.open(THIS_DIR / "pasta.jpeg")), + ImageContentItem(data=PASTA_IMAGE), ["spaghetti"], ), ( - ImageMedia( - image=URL( + ImageContentItem( + data=URL( uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg" ) ), @@ -58,7 +60,12 @@ class TestVisionModelInference: model_id=inference_model, messages=[ UserMessage(content="You are a helpful assistant."), - UserMessage(content=[image, "Describe this image in two sentences."]), + UserMessage( + content=[ + image, + TextContentItem(text="Describe this image in two sentences."), + ] + ), ], stream=False, sampling_params=SamplingParams(max_tokens=100), @@ -89,8 +96,8 @@ class TestVisionModelInference: ) images = [ - ImageMedia( - image=URL( + ImageContentItem( + data=URL( uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg" ) ), @@ -106,7 +113,12 @@ class TestVisionModelInference: messages=[ UserMessage(content="You are a helpful assistant."), UserMessage( - content=[image, "Describe this image in two sentences."] + content=[ + image, + TextContentItem( + text="Describe this image in two sentences." 
+ ), + ] ), ], stream=True, diff --git a/llama_stack/providers/tests/post_training/fixtures.py b/llama_stack/providers/tests/post_training/fixtures.py index 3ca48d847..eb7f3a66b 100644 --- a/llama_stack/providers/tests/post_training/fixtures.py +++ b/llama_stack/providers/tests/post_training/fixtures.py @@ -7,8 +7,8 @@ import pytest import pytest_asyncio -from llama_models.llama3.api.datatypes import URL from llama_stack.apis.common.type_system import * # noqa: F403 +from llama_stack.apis.common.deployment_types import URL from llama_stack.apis.datasets import DatasetInput from llama_stack.apis.models import ModelInput diff --git a/llama_stack/providers/utils/datasetio/url_utils.py b/llama_stack/providers/utils/datasetio/url_utils.py index 3faea9f95..4e99a3daf 100644 --- a/llama_stack/providers/utils/datasetio/url_utils.py +++ b/llama_stack/providers/utils/datasetio/url_utils.py @@ -10,7 +10,7 @@ from urllib.parse import unquote import pandas -from llama_models.llama3.api.datatypes import URL +from llama_stack.apis.common.deployment_types import URL from llama_stack.providers.utils.memory.vector_store import parse_data_url diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index b53f8cd32..5800bf0e0 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -7,9 +7,11 @@ import logging from typing import List -from llama_models.llama3.api.datatypes import InterleavedTextMedia - -from llama_stack.apis.inference.inference import EmbeddingsResponse, ModelStore +from llama_stack.apis.inference import ( + EmbeddingsResponse, + InterleavedContent, + ModelStore, +) EMBEDDING_MODELS = {} @@ -23,7 +25,7 @@ class SentenceTransformerEmbeddingMixin: async def embeddings( self, model_id: str, - contents: List[InterleavedTextMedia], + contents: List[InterleavedContent], ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) embedding_model = self._load_sentence_transformer_model( diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index fff0f539d..0f1e6894e 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -93,11 +93,15 @@ def process_chat_completion_response( ) -> ChatCompletionResponse: choice = response.choices[0] - completion_message = formatter.decode_assistant_message_from_content( + raw_message = formatter.decode_assistant_message_from_content( text_from_choice(choice), get_stop_reason(choice.finish_reason) ) return ChatCompletionResponse( - completion_message=completion_message, + completion_message=CompletionMessage( + content=raw_message.content, + stop_reason=raw_message.stop_reason, + tool_calls=raw_message.tool_calls, + ), logprobs=None, ) diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 5ff98f4ee..fb6a6dcfc 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -6,6 +6,7 @@ import asyncio import base64 +import io import json import logging import re @@ -21,7 +22,6 @@ from llama_models.llama3.api.datatypes import ( RawMediaItem, RawTextItem, Role, - ToolChoice, ToolPromptFormat, ) from llama_models.llama3.prompt_templates import ( @@ -47,6 +47,7 @@ from llama_stack.apis.inference import ( 
ResponseFormatType, SystemMessage, TextContentItem, + ToolChoice, UserMessage, ) @@ -136,7 +137,7 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]): async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]: if isinstance(media.data, URL) and media.data.uri.startswith("http"): async with httpx.AsyncClient() as client: - r = await client.get(media.image.uri) + r = await client.get(media.data.uri) content = r.content content_type = r.headers.get("content-type") if content_type: @@ -145,7 +146,7 @@ async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]: format = "png" return content, format else: - image = PIL_Image.open(media.data) + image = PIL_Image.open(io.BytesIO(media.data)) return media.data, image.format @@ -153,7 +154,7 @@ async def convert_image_content_to_url( media: ImageContentItem, download: bool = False, include_format: bool = True ) -> str: if isinstance(media.data, URL) and not download: - return media.image.uri + return media.data.uri content, format = await localize_image_content(media) if include_format: diff --git a/llama_stack/providers/utils/memory/file_utils.py b/llama_stack/providers/utils/memory/file_utils.py index bc4462fa0..9ea3397fd 100644 --- a/llama_stack/providers/utils/memory/file_utils.py +++ b/llama_stack/providers/utils/memory/file_utils.py @@ -8,7 +8,7 @@ import base64 import mimetypes import os -from llama_models.llama3.api.datatypes import URL +from llama_stack.apis.common.deployment_types import URL def data_url_from_file(file_path: str) -> URL:
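
Reviewer note, not part of the patch: most hunks above are mechanical -- InterleavedTextMedia becomes InterleavedContent, and URL is now imported from llama_stack.apis.common.deployment_types rather than llama_models.llama3.api.datatypes. Below is a minimal sketch of what message content looks like after the migration, mirroring the updated test_vision_inference.py hunks. Assumptions: ImageContentItem is exported from llama_stack.apis.inference alongside TextContentItem and UserMessage (only the latter two appear in the prompt_adapter import hunk, so ImageContentItem's exact export path is inferred), and the example URI is a placeholder.

    # Sketch only -- mirrors the content-item usage in the patched vision test.
    from llama_stack.apis.common.deployment_types import URL
    from llama_stack.apis.inference import (  # ImageContentItem path assumed
        ImageContentItem,
        TextContentItem,
        UserMessage,
    )

    # Image content may wrap raw bytes or a URL (both forms appear in the
    # patched test); here a URL, as in the streaming test case.
    image = ImageContentItem(data=URL(uri="https://example.com/dog.jpg"))

    # Mixed image + text content replaces the old bare-string / ImageMedia list.
    message = UserMessage(
        content=[
            image,
            TextContentItem(text="Describe this image in two sentences."),
        ]
    )

Note also that the prompt_adapter.py hunks fix two latent attribute bugs exposed by the rename: localize_image_content and convert_image_content_to_url previously read media.image.uri where the field is media.data, and raw image bytes are now wrapped in io.BytesIO before PIL_Image.open, since PIL expects a file-like object rather than bytes.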