Merge branch 'rag_scoring_fn_1' into rag_scoring_fn_2

Xi Yan 2024-12-30 17:20:35 -08:00
commit dbecff60a4
128 changed files with 6391 additions and 493 deletions

View file

@@ -4,17 +4,28 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, Optional
from typing import AsyncGenerator, List, Optional
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason
from llama_stack.apis.inference import * # noqa: F403
from llama_models.llama3.api.datatypes import SamplingParams, StopReason
from pydantic import BaseModel
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
Message,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
)

View file

@@ -40,7 +40,6 @@ from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
TextContentItem,
URL,
)
from llama_stack.apis.inference import (
@@ -94,9 +93,14 @@ async def convert_request_to_raw(
d = m.model_dump()
d["content"] = content
messages.append(RawMessage(**d))
request.messages = messages
d = request.model_dump()
d["messages"] = messages
request = ChatCompletionRequestWithRawContent(**d)
else:
request.content = await interleaved_content_convert_to_raw(request.content)
d = request.model_dump()
d["content"] = await interleaved_content_convert_to_raw(request.content)
request = CompletionRequestWithRawContent(**d)
return request
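
The rewritten hunk above stops mutating the validated request in place and instead dumps it to a dict, swaps in the converted field, and re-validates it as the *WithRawContent variant. A minimal sketch of that pydantic rebuild pattern, using hypothetical stand-in models rather than the real request classes:

from typing import List

from pydantic import BaseModel


class Request(BaseModel):
    # Hypothetical stand-in for ChatCompletionRequest.
    model: str
    messages: List[str]


class RequestWithRawContent(Request):
    # Hypothetical stand-in for ChatCompletionRequestWithRawContent.
    pass


def to_raw(request: Request, raw_messages: List[str]) -> RequestWithRawContent:
    # model_dump() yields a plain dict; replacing the field there and
    # re-validating avoids assigning a differently-typed value onto an
    # already-validated model instance.
    d = request.model_dump()
    d["messages"] = raw_messages
    return RequestWithRawContent(**d)
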
@@ -112,27 +116,31 @@ async def interleaved_content_convert_to_raw(
elif isinstance(c, TextContentItem):
return RawTextItem(text=c.text)
elif isinstance(c, ImageContentItem):
# load image and return PIL version
img = c.data
if isinstance(img, URL):
if img.uri.startswith("data"):
match = re.match(r"data:image/(\w+);base64,(.+)", img.uri)
if c.url:
# Load image bytes from URL
if c.url.uri.startswith("data"):
match = re.match(r"data:image/(\w+);base64,(.+)", c.url.uri)
if not match:
raise ValueError("Invalid data URL format")
raise ValueError(
f"Invalid data URL format, {c.url.uri[:40]}..."
)
_, image_data = match.groups()
data = base64.b64decode(image_data)
elif img.uri.startswith("file://"):
path = img.uri[len("file://") :]
elif c.url.uri.startswith("file://"):
path = c.url.uri[len("file://") :]
with open(path, "rb") as f:
data = f.read() # type: ignore
elif img.uri.startswith("http"):
elif c.url.uri.startswith("http"):
async with httpx.AsyncClient() as client:
response = await client.get(img.uri)
response = await client.get(c.url.uri)
data = response.content
else:
raise ValueError("Unsupported URL type")
else:
elif c.data:
data = c.data
else:
raise ValueError("No data or URL provided")
return RawMediaItem(data=data)
else:
raise ValueError(f"Unsupported content type: {type(c)}")
@@ -277,7 +285,8 @@ def chat_completion_request_to_messages(
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
messages = augment_messages_for_tools_llama_3_1(request)
elif model.model_family == ModelFamily.llama3_2:
elif model.model_family in (ModelFamily.llama3_2, ModelFamily.llama3_3):
# llama3.2 and llama3.3 models follow the same tool prompt format
messages = augment_messages_for_tools_llama_3_2(request)
else:
messages = request.messages
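
The widened elif above routes llama3.3 through the llama3.2 tool-prompt path because both families share the format. A hedged sketch of the same dispatch expressed as a lookup table instead of if/elif (enum values and handler names here are illustrative, not the module's real API):

from enum import Enum
from typing import Callable, Dict, List


class ModelFamily(Enum):
    # Hypothetical subset of the real ModelFamily enum.
    llama3_1 = "llama3_1"
    llama3_2 = "llama3_2"
    llama3_3 = "llama3_3"


def augment_for_llama_3_1(messages: List[str]) -> List[str]:
    return messages  # placeholder for the 3.1-style tool prompt


def augment_for_llama_3_2(messages: List[str]) -> List[str]:
    return messages  # placeholder for the shared 3.2/3.3-style tool prompt


# llama3.2 and llama3.3 deliberately map to the same handler.
TOOL_PROMPT_HANDLERS: Dict[ModelFamily, Callable[[List[str]], List[str]]] = {
    ModelFamily.llama3_1: augment_for_llama_3_1,
    ModelFamily.llama3_2: augment_for_llama_3_2,
    ModelFamily.llama3_3: augment_for_llama_3_2,
}


def augment_messages(family: ModelFamily, messages: List[str]) -> List[str]:
    # Fall through to the raw messages when no family-specific format applies.
    return TOOL_PROMPT_HANDLERS.get(family, lambda m: m)(messages)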

View file

@@ -4,8 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .api import * # noqa: F403
from .config import * # noqa: F403
from typing import List, Optional
from .api import KVStore
from .config import KVStoreConfig, KVStoreType
def kvstore_dependencies():

View file

@@ -9,7 +9,7 @@ from typing import List, Optional
from redis.asyncio import Redis
from ..api import * # noqa: F403
from ..api import KVStore
from ..config import RedisKVStoreConfig

View file

@@ -11,7 +11,7 @@ from typing import List, Optional
import aiosqlite
from ..api import * # noqa: F403
from ..api import KVStore
from ..config import SqliteKVStoreConfig

View file

@@ -15,14 +15,17 @@ from urllib.parse import unquote
import chardet
import httpx
import numpy as np
from llama_models.llama3.api.tokenizer import Tokenizer
from numpy.typing import NDArray
from pypdf import PdfReader
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent, TextContentItem
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.common.content_types import (
InterleavedContent,
TextContentItem,
URL,
)
from llama_stack.apis.memory import Chunk, MemoryBankDocument, QueryDocumentsResponse
from llama_stack.apis.memory_banks import VectorMemoryBank
from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.inference.prompt_adapter import (

View file

@@ -6,7 +6,8 @@
import statistics
from typing import Any, Dict, List
from llama_stack.apis.scoring import AggregationFunctionType, ScoringResultRow
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import AggregationFunctionType
def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
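
The import fix above pulls AggregationFunctionType from scoring_functions, where the enum lives, while ScoringResultRow stays in scoring. For context, a minimal sketch of what an accuracy aggregation over such rows can look like (treating each row as a dict carrying a binary "score" key, which is an assumption about the row shape):

import statistics
from typing import Any, Dict, List

# Assumed shape: each row is a dict with a binary "score" (1.0 or 0.0).
ScoringResultRow = Dict[str, Any]


def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
    # Mean of per-row binary scores = fraction of rows answered correctly.
    num_correct = sum(row["score"] for row in scoring_results)
    return {
        "accuracy": statistics.mean(row["score"] for row in scoring_results),
        "num_correct": num_correct,
        "num_total": len(scoring_results),
    }


print(aggregate_accuracy([{"score": 1.0}, {"score": 0.0}, {"score": 1.0}]))
# {'accuracy': 0.666..., 'num_correct': 2.0, 'num_total': 3}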

View file

@@ -12,10 +12,18 @@ import threading
import uuid
from datetime import datetime
from functools import wraps
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, Optional
from llama_stack.apis.telemetry import * # noqa: F403
from llama_stack.apis.telemetry import (
LogSeverity,
Span,
SpanEndPayload,
SpanStartPayload,
SpanStatus,
StructuredLogEvent,
Telemetry,
UnstructuredLogEvent,
)
from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
log = logging.getLogger(__name__)