Update the "InterleavedTextMedia" type (#635)

## What does this PR do?

This is a long-pending change and particularly important to get done
now.

Specifically:
- we cannot "localize" (aka download) any URLs from media attachments
anywhere near our modeling code. it must be done within llama-stack.
- `PIL.Image` is infesting all our APIs via `ImageMedia ->
InterleavedTextMedia` and that cannot be right at all. Anything in the
API surface must be "naturally serializable". We need a standard `{
type: "image", image_url: "<...>" }` which is more extensible
- `UserMessage`, `SystemMessage`, etc. are moved completely to
llama-stack from the llama-models repository.

See https://github.com/meta-llama/llama-models/pull/244 for the
corresponding PR in llama-models.

## Test Plan

```bash
cd llama_stack/providers/tests

pytest -s -v -k "fireworks or ollama or together" inference/test_vision_inference.py
pytest -s -v -k "(fireworks or ollama or together) and llama_3b" inference/test_text_inference.py
pytest -s -v -k chroma memory/test_memory.py \
  --env EMBEDDING_DIMENSION=384 --env CHROMA_DB_PATH=/tmp/foobar

pytest -s -v -k fireworks agents/test_agents.py  \
   --safety-shield=meta-llama/Llama-Guard-3-8B \
   --inference-model=meta-llama/Llama-3.1-8B-Instruct
```

Updated the client sdk (see PR ...), installed the SDK in the same
environment and then ran the SDK tests:

```bash
cd tests/client-sdk
LLAMA_STACK_CONFIG=together pytest -s -v agents/test_agents.py
LLAMA_STACK_CONFIG=ollama pytest -s -v memory/test_memory.py

# this one needed a bit of hacking in the run.yaml to ensure I could register the vision model correctly
INFERENCE_MODEL=llama3.2-vision:latest LLAMA_STACK_CONFIG=ollama pytest -s -v inference/test_inference.py
```
This commit is contained in:
Ashwin Bharambe 2024-12-17 11:18:31 -08:00 committed by GitHub
parent 10eb31badf
commit 8de8eb03c8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
66 changed files with 1344 additions and 1801 deletions

View file

@ -7,25 +7,60 @@
import asyncio
import logging
from typing import AsyncGenerator, List
from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import Model
from llama_models.llama3.api.datatypes import (
RawMessage,
SamplingParams,
StopReason,
ToolDefinition,
ToolPromptFormat,
)
from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
Inference,
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
TokenLogProbs,
ToolCallDelta,
ToolCallParseStatus,
ToolChoice,
)
from llama_stack.providers.utils.inference.model_registry import build_model_alias
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.models import ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_media_to_url,
request_has_media,
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
interleaved_content_convert_to_raw,
)
from .config import MetaReferenceInferenceConfig
from .generation import Llama
from .generation import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
Llama,
)
from .model_parallel import LlamaModelParallelGenerator
log = logging.getLogger(__name__)
@ -90,7 +125,7 @@ class MetaReferenceInferenceImpl(
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -99,6 +134,7 @@ class MetaReferenceInferenceImpl(
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
content = augment_content_with_response_format_prompt(response_format, content)
request = CompletionRequest(
model=model_id,
content=content,
@ -108,7 +144,7 @@ class MetaReferenceInferenceImpl(
logprobs=logprobs,
)
self.check_model(request)
request = await request_with_localized_media(request)
request = await convert_request_to_raw(request)
if request.stream:
return self._stream_completion(request)
@ -233,7 +269,13 @@ class MetaReferenceInferenceImpl(
logprobs=logprobs,
)
self.check_model(request)
request = await request_with_localized_media(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(
request, self.model.core_model_id.value
)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
@ -274,11 +316,15 @@ class MetaReferenceInferenceImpl(
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(
raw_message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
return ChatCompletionResponse(
completion_message=message,
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=logprobs if request.logprobs else None,
)
@ -406,29 +452,18 @@ class MetaReferenceInferenceImpl(
yield x
async def request_with_localized_media(
async def convert_request_to_raw(
request: Union[ChatCompletionRequest, CompletionRequest],
) -> Union[ChatCompletionRequest, CompletionRequest]:
if not request_has_media(request):
return request
async def _convert_single_content(content):
if isinstance(content, ImageMedia):
url = await convert_image_media_to_url(content, download=True)
return ImageMedia(image=URL(uri=url))
else:
return content
async def _convert_content(content):
if isinstance(content, list):
return [await _convert_single_content(c) for c in content]
else:
return await _convert_single_content(content)
) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
if isinstance(request, ChatCompletionRequest):
messages = []
for m in request.messages:
m.content = await _convert_content(m.content)
content = await interleaved_content_convert_to_raw(m.content)
d = m.model_dump()
d["content"] = content
messages.append(RawMessage(**d))
request.messages = messages
else:
request.content = await _convert_content(request.content)
request.content = await interleaved_content_convert_to_raw(request.content)
return request