Update the "InterleavedTextMedia" type (#635)

## What does this PR do?

This is a long-pending change and particularly important to get done
now.

Specifically:
- we cannot "localize" (aka download) any URLs from media attachments
anywhere near our modeling code. It must be done within llama-stack.
- `PIL.Image` is infesting all our APIs via `ImageMedia ->
InterleavedTextMedia` and that cannot be right at all. Anything in the
API surface must be "naturally serializable". We need a standard `{
type: "image", image_url: "<...>" }` which is more extensible.
- `UserMessage`, `SystemMessage`, etc. are moved completely to
llama-stack from the llama-models repository.

See https://github.com/meta-llama/llama-models/pull/244 for the
corresponding PR in llama-models.

## Test Plan

```bash
cd llama_stack/providers/tests

pytest -s -v -k "fireworks or ollama or together" inference/test_vision_inference.py
pytest -s -v -k "(fireworks or ollama or together) and llama_3b" inference/test_text_inference.py
pytest -s -v -k chroma memory/test_memory.py \
  --env EMBEDDING_DIMENSION=384 --env CHROMA_DB_PATH=/tmp/foobar

pytest -s -v -k fireworks agents/test_agents.py  \
   --safety-shield=meta-llama/Llama-Guard-3-8B \
   --inference-model=meta-llama/Llama-3.1-8B-Instruct
```

Updated the client SDK (see PR ...), installed the SDK in the same
environment and then ran the SDK tests:

```bash
cd tests/client-sdk
LLAMA_STACK_CONFIG=together pytest -s -v agents/test_agents.py
LLAMA_STACK_CONFIG=ollama pytest -s -v memory/test_memory.py

# this one needed a bit of hacking in the run.yaml to ensure I could register the vision model correctly
INFERENCE_MODEL=llama3.2-vision:latest LLAMA_STACK_CONFIG=ollama pytest -s -v inference/test_inference.py
```
This commit is contained in:
Ashwin Bharambe 2024-12-17 11:18:31 -08:00 committed by GitHub
parent 10eb31badf
commit 8de8eb03c8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
66 changed files with 1344 additions and 1801 deletions

View file

@ -13,10 +13,19 @@ import threading
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
from typing import Any, Generator, get_args, get_origin, Optional, Type, TypeVar, Union
from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
import httpx
import yaml
from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
from llama_stack_client import (
APIResponse,
AsyncAPIResponse,
AsyncLlamaStackClient,
AsyncStream,
LlamaStackClient,
NOT_GIVEN,
)
from pydantic import BaseModel, TypeAdapter
from rich.console import Console
@ -66,7 +75,7 @@ def stream_across_asyncio_run_boundary(
# make sure we make the generator in the event loop context
gen = await async_gen_maker()
try:
async for item in gen:
async for item in await gen:
result_queue.put(item)
except Exception as e:
print(f"Error in generator {e}")
@ -112,31 +121,17 @@ def stream_across_asyncio_run_boundary(
future.result()
def convert_pydantic_to_json_value(value: Any, cast_to: Type) -> dict:
def convert_pydantic_to_json_value(value: Any) -> Any:
if isinstance(value, Enum):
return value.value
elif isinstance(value, list):
return [convert_pydantic_to_json_value(item, cast_to) for item in value]
return [convert_pydantic_to_json_value(item) for item in value]
elif isinstance(value, dict):
return {k: convert_pydantic_to_json_value(v, cast_to) for k, v in value.items()}
return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
elif isinstance(value, BaseModel):
# This is quite hacky and we should figure out how to use stuff from
# generated client-sdk code (using ApiResponse.parse() essentially)
value_dict = json.loads(value.model_dump_json())
origin = get_origin(cast_to)
if origin is Union:
args = get_args(cast_to)
for arg in args:
arg_name = arg.__name__.split(".")[-1]
value_name = value.__class__.__name__.split(".")[-1]
if arg_name == value_name:
return arg(**value_dict)
# assume we have the correct association between the server-side type and the client-side type
return cast_to(**value_dict)
return value
return json.loads(value.model_dump_json())
else:
return value
def convert_to_pydantic(annotation: Any, value: Any) -> Any:
@ -278,16 +273,28 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not self.endpoint_impls:
raise ValueError("Client not initialized")
params = options.params or {}
params |= options.json_data or {}
if stream:
return self._call_streaming(options.url, params, cast_to)
return self._call_streaming(
cast_to=cast_to,
options=options,
stream_cls=stream_cls,
)
else:
return await self._call_non_streaming(options.url, params, cast_to)
return await self._call_non_streaming(
cast_to=cast_to,
options=options,
)
async def _call_non_streaming(
self, path: str, body: dict = None, cast_to: Any = None
self,
*,
cast_to: Any,
options: Any,
):
path = options.url
body = options.params or {}
body |= options.json_data or {}
await start_trace(path, {"__location__": "library_client"})
try:
func = self.endpoint_impls.get(path)
@ -295,11 +302,45 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body)
return convert_pydantic_to_json_value(await func(**body), cast_to)
result = await func(**body)
json_content = json.dumps(convert_pydantic_to_json_value(result))
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=json_content.encode("utf-8"),
headers={
"Content-Type": "application/json",
},
request=httpx.Request(
method=options.method,
url=options.url,
params=options.params,
headers=options.headers,
json=options.json_data,
),
)
response = APIResponse(
raw=mock_response,
client=self,
cast_to=cast_to,
options=options,
stream=False,
stream_cls=None,
)
return response.parse()
finally:
await end_trace()
async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
async def _call_streaming(
self,
*,
cast_to: Any,
options: Any,
stream_cls: Any,
):
path = options.url
body = options.params or {}
body |= options.json_data or {}
await start_trace(path, {"__location__": "library_client"})
try:
func = self.endpoint_impls.get(path)
@ -307,8 +348,42 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body)
async for chunk in await func(**body):
yield convert_pydantic_to_json_value(chunk, cast_to)
async def gen():
async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk))
sse_event = f"data: {data}\n\n"
yield sse_event.encode("utf-8")
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=gen(),
headers={
"Content-Type": "application/json",
},
request=httpx.Request(
method=options.method,
url=options.url,
params=options.params,
headers=options.headers,
json=options.json_data,
),
)
# we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
# however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
# so we need to convert it to AsyncStream
args = get_args(stream_cls)
stream_cls = AsyncStream[args[0]]
response = AsyncAPIResponse(
raw=mock_response,
client=self,
cast_to=cast_to,
options=options,
stream=True,
stream_cls=stream_cls,
)
return await response.parse()
finally:
await end_trace()

View file

@ -59,7 +59,7 @@ class MemoryRouter(Memory):
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
return await self.routing_table.get_provider_impl(bank_id).query_documents(
@ -133,7 +133,7 @@ class InferenceRouter(Inference):
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -163,7 +163,7 @@ class InferenceRouter(Inference):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
model = await self.routing_table.get_model(model_id)
if model is None:

View file

@ -16,8 +16,7 @@ from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_models.llama3.api.datatypes import URL
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.distribution.store import DistributionRegistry
@ -30,7 +29,6 @@ def get_impl_api(p: Any) -> Api:
# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
api = get_impl_api(p)
assert obj.provider_id != "remote", "Remote provider should not be registered"
@ -76,7 +74,6 @@ class CommonRoutingTableImpl(RoutingTable):
self.dist_registry = dist_registry
async def initialize(self) -> None:
async def add_objects(
objs: List[RoutableObjectWithProvider], provider_id: str, cls
) -> None:

View file

@ -6,6 +6,7 @@
import logging
import os
import re
from pathlib import Path
from typing import Any, Dict
@ -143,7 +144,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
if default_val is None:
raise EnvVarError(env_var, path)
else:
value = default_val
value = default_val if default_val != "null" else None
# expand "~" from the values
return os.path.expanduser(value)

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import json
from contextlib import asynccontextmanager
from typing import Dict, List, Optional, Protocol, Tuple
@ -54,10 +53,7 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
all_objects = []
for value in values:
obj = pydantic.parse_obj_as(
RoutableObjectWithProvider,
json.loads(value),
)
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
all_objects.append(obj)
return all_objects
@ -89,14 +85,7 @@ class DiskDistributionRegistry(DistributionRegistry):
if not json_str:
return None
objects_data = json.loads(json_str)
# Return only the first object if any exist
if objects_data:
return pydantic.parse_obj_as(
RoutableObjectWithProvider,
json.loads(objects_data),
)
return None
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
async def update(self, obj: RoutableObjectWithProvider) -> None:
await self.kvstore.set(