Update the "InterleavedTextMedia" type (#635)

## What does this PR do?

This is a long-pending change and particularly important to get done
now.

Specifically:
- we cannot "localize" (aka download) any URLs from media attachments
anywhere near our modeling code. it must be done within llama-stack.
- `PIL.Image` is infesting all our APIs via `ImageMedia ->
InterleavedTextMedia` and that cannot be right at all. Anything in the
API surface must be "naturally serializable". We need a standard `{
type: "image", image_url: "<...>" }` which is more extensible
- `UserMessage`, `SystemMessage`, etc. are moved completely to
llama-stack from the llama-models repository.

See https://github.com/meta-llama/llama-models/pull/244 for the
corresponding PR in llama-models.

## Test Plan

```bash
cd llama_stack/providers/tests

pytest -s -v -k "fireworks or ollama or together" inference/test_vision_inference.py
pytest -s -v -k "(fireworks or ollama or together) and llama_3b" inference/test_text_inference.py
pytest -s -v -k chroma memory/test_memory.py \
  --env EMBEDDING_DIMENSION=384 --env CHROMA_DB_PATH=/tmp/foobar

pytest -s -v -k fireworks agents/test_agents.py  \
   --safety-shield=meta-llama/Llama-Guard-3-8B \
   --inference-model=meta-llama/Llama-3.1-8B-Instruct
```

Updated the client sdk (see PR ...), installed the SDK in the same
environment and then ran the SDK tests:

```bash
cd tests/client-sdk
LLAMA_STACK_CONFIG=together pytest -s -v agents/test_agents.py
LLAMA_STACK_CONFIG=ollama pytest -s -v memory/test_memory.py

# this one needed a bit of hacking in the run.yaml to ensure I could register the vision model correctly
INFERENCE_MODEL=llama3.2-vision:latest LLAMA_STACK_CONFIG=ollama pytest -s -v inference/test_inference.py
```
This commit is contained in:
Ashwin Bharambe 2024-12-17 11:18:31 -08:00 committed by GitHub
parent 10eb31badf
commit 8de8eb03c8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
66 changed files with 1344 additions and 1801 deletions

View file

@ -26,6 +26,7 @@ from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from llama_stack.providers.utils.telemetry import tracing
from .persistence import AgentPersistence
@ -389,7 +390,7 @@ class ChatAgent(ShieldRunnerMixin):
if rag_context:
last_message = input_messages[-1]
last_message.context = "\n".join(rag_context)
last_message.context = rag_context
elif attachments and AgentTool.code_interpreter.value in enabled_tools:
urls = [a.content for a in attachments if isinstance(a.content, URL)]
@ -655,7 +656,7 @@ class ChatAgent(ShieldRunnerMixin):
async def _retrieve_context(
self, session_id: str, messages: List[Message], attachments: List[Attachment]
) -> Tuple[Optional[List[str]], Optional[List[int]]]: # (rag_context, bank_ids)
) -> Tuple[Optional[InterleavedContent], List[int]]: # (rag_context, bank_ids)
bank_ids = []
memory = self._memory_tool_definition()
@ -723,11 +724,16 @@ class ChatAgent(ShieldRunnerMixin):
break
picked.append(f"id:{c.document_id}; content:{c.content}")
return [
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
*picked,
"\n=== END-RETRIEVED-CONTEXT ===\n",
], bank_ids
return (
concat_interleaved_content(
[
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
*picked,
"\n=== END-RETRIEVED-CONTEXT ===\n",
]
),
bank_ids,
)
def _get_tools(self) -> List[ToolDefinition]:
ret = []

View file

@ -17,6 +17,9 @@ from llama_stack.apis.agents import (
MemoryQueryGeneratorConfig,
)
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
async def generate_rag_query(
@ -42,7 +45,7 @@ async def default_rag_query_generator(
messages: List[Message],
**kwargs,
):
return config.sep.join(interleaved_text_media_as_str(m.content) for m in messages)
return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
async def llm_rag_query_generator(

View file

@ -9,8 +9,6 @@ import logging
from typing import List
from llama_models.llama3.api.datatypes import Message
from llama_stack.apis.safety import * # noqa: F403
log = logging.getLogger(__name__)

View file

@ -36,7 +36,7 @@ def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
snippet = match.group(1)
data = json.loads(snippet)
return Attachment(
content=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"]
url=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"]
)
return None

View file

@ -24,7 +24,8 @@ from fairscale.nn.model_parallel.initialize import (
model_parallel_is_initialized,
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
from llama_models.llama3.api.datatypes import RawContent, RawMessage
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
@ -38,10 +39,6 @@ from llama_stack.apis.inference import * # noqa: F403
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
)
from .config import (
Fp8QuantizationConfig,
@ -53,6 +50,14 @@ from .config import (
log = logging.getLogger(__name__)
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
messages: List[RawMessage]
class CompletionRequestWithRawContent(CompletionRequest):
content: RawContent
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
@ -206,7 +211,7 @@ class Llama:
@torch.inference_mode()
def generate(
self,
model_input: ModelInput,
model_input: LLMInput,
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
@ -343,7 +348,7 @@ class Llama:
def completion(
self,
request: CompletionRequest,
request: CompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
@ -354,10 +359,7 @@ class Llama:
):
max_gen_len = self.model.params.max_seq_len - 1
content = augment_content_with_response_format_prompt(
request.response_format, request.content
)
model_input = self.formatter.encode_content(content)
model_input = self.formatter.encode_content(request.content)
yield from self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
@ -374,10 +376,8 @@ class Llama:
def chat_completion(
self,
request: ChatCompletionRequest,
request: ChatCompletionRequestWithRawContent,
) -> Generator:
messages = chat_completion_request_to_messages(request, self.llama_model)
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
@ -389,7 +389,7 @@ class Llama:
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(
messages,
request.messages,
request.tool_prompt_format,
),
max_gen_len=max_gen_len,

View file

@ -7,25 +7,60 @@
import asyncio
import logging
from typing import AsyncGenerator, List
from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import Model
from llama_models.llama3.api.datatypes import (
RawMessage,
SamplingParams,
StopReason,
ToolDefinition,
ToolPromptFormat,
)
from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
Inference,
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
TokenLogProbs,
ToolCallDelta,
ToolCallParseStatus,
ToolChoice,
)
from llama_stack.providers.utils.inference.model_registry import build_model_alias
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.models import ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_media_to_url,
request_has_media,
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
interleaved_content_convert_to_raw,
)
from .config import MetaReferenceInferenceConfig
from .generation import Llama
from .generation import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
Llama,
)
from .model_parallel import LlamaModelParallelGenerator
log = logging.getLogger(__name__)
@ -90,7 +125,7 @@ class MetaReferenceInferenceImpl(
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -99,6 +134,7 @@ class MetaReferenceInferenceImpl(
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
content = augment_content_with_response_format_prompt(response_format, content)
request = CompletionRequest(
model=model_id,
content=content,
@ -108,7 +144,7 @@ class MetaReferenceInferenceImpl(
logprobs=logprobs,
)
self.check_model(request)
request = await request_with_localized_media(request)
request = await convert_request_to_raw(request)
if request.stream:
return self._stream_completion(request)
@ -233,7 +269,13 @@ class MetaReferenceInferenceImpl(
logprobs=logprobs,
)
self.check_model(request)
request = await request_with_localized_media(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(
request, self.model.core_model_id.value
)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
@ -274,11 +316,15 @@ class MetaReferenceInferenceImpl(
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(
raw_message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
return ChatCompletionResponse(
completion_message=message,
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=logprobs if request.logprobs else None,
)
@ -406,29 +452,18 @@ class MetaReferenceInferenceImpl(
yield x
async def request_with_localized_media(
async def convert_request_to_raw(
request: Union[ChatCompletionRequest, CompletionRequest],
) -> Union[ChatCompletionRequest, CompletionRequest]:
if not request_has_media(request):
return request
async def _convert_single_content(content):
if isinstance(content, ImageMedia):
url = await convert_image_media_to_url(content, download=True)
return ImageMedia(image=URL(uri=url))
else:
return content
async def _convert_content(content):
if isinstance(content, list):
return [await _convert_single_content(c) for c in content]
else:
return await _convert_single_content(content)
) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
if isinstance(request, ChatCompletionRequest):
messages = []
for m in request.messages:
m.content = await _convert_content(m.content)
content = await interleaved_content_convert_to_raw(m.content)
d = m.model_dump()
d["content"] = content
messages.append(RawMessage(**d))
request.messages = messages
else:
request.content = await _convert_content(request.content)
request.content = await interleaved_content_convert_to_raw(request.content)
return request

View file

@ -114,7 +114,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -218,8 +218,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
yield chunk
async def embeddings(
self, model_id: str, contents: list[InterleavedTextMedia]
self, model_id: str, contents: List[InterleavedContent]
) -> EmbeddingsResponse:
log.info("vLLM embeddings")
# TODO
raise NotImplementedError()

View file

@ -4,12 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import ChromaInlineImplConfig
async def get_provider_impl(config: ChromaInlineImplConfig, _deps):
async def get_provider_impl(
config: ChromaInlineImplConfig, deps: Dict[Api, ProviderSpec]
):
from llama_stack.providers.remote.memory.chroma.chroma import ChromaMemoryAdapter
impl = ChromaMemoryAdapter(config)
impl = ChromaMemoryAdapter(config, deps[Api.inference])
await impl.initialize()
return impl

View file

@ -19,9 +19,10 @@ from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
@ -208,7 +209,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = self.cache.get(bank_id)

View file

@ -7,13 +7,17 @@
import logging
from typing import Any, Dict, List
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.inference import Message
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from .config import CodeScannerConfig
from llama_stack.apis.safety import * # noqa: F403
log = logging.getLogger(__name__)
ALLOWED_CODE_SCANNER_MODEL_IDS = [
"CodeScanner",
"CodeShield",
@ -48,7 +52,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
from codeshield.cs import CodeShield
text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages])
text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
log.info(f"Running CodeScannerShield on {text[50:]}")
result = await CodeShield.scan_code(text)

View file

@ -12,9 +12,13 @@ from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from .config import LlamaGuardConfig
@ -258,18 +262,18 @@ class LlamaGuardShield:
most_recent_img = None
for m in messages[::-1]:
if isinstance(m.content, str):
if isinstance(m.content, str) or isinstance(m.content, TextContentItem):
conversation.append(m)
elif isinstance(m.content, ImageMedia):
elif isinstance(m.content, ImageContentItem):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = m.content
conversation.append(m)
elif isinstance(m.content, list):
content = []
for c in m.content:
if isinstance(c, str):
if isinstance(c, str) or isinstance(c, TextContentItem):
content.append(c)
elif isinstance(c, ImageMedia):
elif isinstance(c, ImageContentItem):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = c
content.append(c)
@ -292,7 +296,7 @@ class LlamaGuardShield:
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
[
f"{m.role.capitalize()}: {interleaved_text_media_as_str(m.content)}"
f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}"
for m in messages
]
)

View file

@ -17,6 +17,9 @@ from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from .config import PromptGuardConfig, PromptGuardType
@ -83,7 +86,7 @@ class PromptGuardShield:
async def run(self, messages: List[Message]) -> RunShieldResponse:
message = messages[-1]
text = interleaved_text_media_as_str(message.content)
text = interleaved_content_as_str(message.content)
# run model on messages and return response
inputs = self.tokenizer(text, return_tensors="pt")