Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-23 02:22:25 +00:00
Merge branch 'rag_scoring_fn_1' into rag_scoring_fn_2
commit dbecff60a4
128 changed files with 6391 additions and 493 deletions

@@ -4,17 +4,28 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator, Optional
+from typing import AsyncGenerator, List, Optional
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import StopReason
-
-from llama_stack.apis.inference import *  # noqa: F403
+from llama_models.llama3.api.datatypes import SamplingParams, StopReason
+from pydantic import BaseModel
 
+from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
+from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    CompletionMessage,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
+    Message,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_image_content_to_url,
+)
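
This hunk, like several below, replaces a flake8-suppressed wildcard import with explicit names. A minimal sketch of the pattern, using names from this commit for illustration:

# Before: the linter cannot tell which names are used, so F403 has
# to be suppressed and unused imports go undetected.
# from llama_stack.apis.inference import *  # noqa: F403

# After: every dependency is named, so unused imports (F401) and
# undefined names (F821) become detectable again.
from llama_stack.apis.inference import (
    ChatCompletionResponse,
    CompletionMessage,
)
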

@@ -40,7 +40,6 @@ from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
     TextContentItem,
-    URL,
 )
 
 from llama_stack.apis.inference import (

@@ -94,9 +93,14 @@ async def convert_request_to_raw(
             d = m.model_dump()
             d["content"] = content
             messages.append(RawMessage(**d))
-        request.messages = messages
+
+        d = request.model_dump()
+        d["messages"] = messages
+        request = ChatCompletionRequestWithRawContent(**d)
     else:
-        request.content = await interleaved_content_convert_to_raw(request.content)
+        d = request.model_dump()
+        d["content"] = await interleaved_content_convert_to_raw(request.content)
+        request = CompletionRequestWithRawContent(**d)
 
     return request
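
The hunk above replaces in-place mutation of the request with construction of a new *WithRawContent model: the validated request is dumped to a dict with model_dump(), the converted field is substituted, and the dict is re-validated as the raw variant. A self-contained sketch of that round-trip, with stand-in models rather than the stack's own:

from pydantic import BaseModel


class Request(BaseModel):
    content: str


class RequestWithRawContent(BaseModel):
    content: bytes


def to_raw(request: Request) -> RequestWithRawContent:
    # Dump the validated model to a plain dict, swap the converted
    # field, then re-validate as the raw variant.
    d = request.model_dump()
    d["content"] = request.content.encode("utf-8")
    return RequestWithRawContent(**d)


print(to_raw(Request(content="hello")))  # content=b'hello'
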

@@ -112,27 +116,31 @@ async def interleaved_content_convert_to_raw(
         elif isinstance(c, TextContentItem):
             return RawTextItem(text=c.text)
         elif isinstance(c, ImageContentItem):
-            # load image and return PIL version
-            img = c.data
-            if isinstance(img, URL):
-                if img.uri.startswith("data"):
-                    match = re.match(r"data:image/(\w+);base64,(.+)", img.uri)
+            if c.url:
+                # Load image bytes from URL
+                if c.url.uri.startswith("data"):
+                    match = re.match(r"data:image/(\w+);base64,(.+)", c.url.uri)
                     if not match:
-                        raise ValueError("Invalid data URL format")
+                        raise ValueError(
+                            f"Invalid data URL format, {c.url.uri[:40]}..."
+                        )
                     _, image_data = match.groups()
                     data = base64.b64decode(image_data)
-                elif img.uri.startswith("file://"):
-                    path = img.uri[len("file://") :]
+                elif c.url.uri.startswith("file://"):
+                    path = c.url.uri[len("file://") :]
                     with open(path, "rb") as f:
                         data = f.read()  # type: ignore
-                elif img.uri.startswith("http"):
+                elif c.url.uri.startswith("http"):
                     async with httpx.AsyncClient() as client:
-                        response = await client.get(img.uri)
+                        response = await client.get(c.url.uri)
                         data = response.content
                 else:
                     raise ValueError("Unsupported URL type")
-            else:
+            elif c.data:
                 data = c.data
+            else:
+                raise ValueError("No data or URL provided")
 
             return RawMediaItem(data=data)
         else:
             raise ValueError(f"Unsupported content type: {type(c)}")
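
The rewritten branch resolves image bytes from three URL schemes (data:, file://, and http/https) before falling back to inline data, and now raises if neither a URL nor data is present. The same dispatch can be sketched independently of the stack's content types; load_image_bytes is a hypothetical helper, not part of the commit:

import base64
import re

import httpx


async def load_image_bytes(uri: str) -> bytes:
    # data: URLs carry the payload inline, base64-encoded.
    if uri.startswith("data"):
        match = re.match(r"data:image/(\w+);base64,(.+)", uri)
        if not match:
            raise ValueError(f"Invalid data URL format, {uri[:40]}...")
        _, image_data = match.groups()
        return base64.b64decode(image_data)
    # file:// URLs are read from the local filesystem.
    elif uri.startswith("file://"):
        with open(uri[len("file://") :], "rb") as f:
            return f.read()
    # http(s) URLs are fetched over the network.
    elif uri.startswith("http"):
        async with httpx.AsyncClient() as client:
            response = await client.get(uri)
            return response.content
    raise ValueError("Unsupported URL type")
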

@@ -277,7 +285,8 @@ def chat_completion_request_to_messages(
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         messages = augment_messages_for_tools_llama_3_1(request)
-    elif model.model_family == ModelFamily.llama3_2:
+    elif model.model_family in (ModelFamily.llama3_2, ModelFamily.llama3_3):
+        # llama3.2 and llama3.3 models follow the same tool prompt format
         messages = augment_messages_for_tools_llama_3_2(request)
     else:
         messages = request.messages
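
Here a single-family equality test becomes a membership test so that llama3.2 and llama3.3 share one tool prompt format. A reduced sketch of the dispatch; the enum and return values below are stand-ins for the stack's ModelFamily and augment_* helpers:

from enum import Enum


class ModelFamily(Enum):
    llama3_1 = "llama3_1"
    llama3_2 = "llama3_2"
    llama3_3 = "llama3_3"


def tool_prompt_format(family: ModelFamily) -> str:
    if family == ModelFamily.llama3_1:
        return "llama_3_1"
    # llama3.2 and llama3.3 follow the same tool prompt format,
    # so both dispatch to the llama_3_2 handler.
    elif family in (ModelFamily.llama3_2, ModelFamily.llama3_3):
        return "llama_3_2"
    return "default"


assert tool_prompt_format(ModelFamily.llama3_3) == "llama_3_2"
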

@@ -4,8 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .api import *  # noqa: F403
-from .config import *  # noqa: F403
+from typing import List, Optional
+
+from .api import KVStore
+from .config import KVStoreConfig, KVStoreType
 
 
 def kvstore_dependencies():

@@ -9,7 +9,7 @@ from typing import List, Optional
 
 from redis.asyncio import Redis
 
-from ..api import *  # noqa: F403
+from ..api import KVStore
 from ..config import RedisKVStoreConfig
 
 

@@ -11,7 +11,7 @@ from typing import List, Optional
 
 import aiosqlite
 
-from ..api import *  # noqa: F403
+from ..api import KVStore
 from ..config import SqliteKVStoreConfig
 
 

@@ -15,14 +15,17 @@ from urllib.parse import unquote
 import chardet
 import httpx
 import numpy as np
 
-from llama_models.llama3.api.tokenizer import Tokenizer
 from numpy.typing import NDArray
 from pypdf import PdfReader
 
-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.common.content_types import InterleavedContent, TextContentItem
-from llama_stack.apis.memory import *  # noqa: F403
+from llama_models.llama3.api.tokenizer import Tokenizer
+
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    TextContentItem,
+    URL,
+)
+from llama_stack.apis.memory import Chunk, MemoryBankDocument, QueryDocumentsResponse
 from llama_stack.apis.memory_banks import VectorMemoryBank
 from llama_stack.providers.datatypes import Api
 from llama_stack.providers.utils.inference.prompt_adapter import (

@@ -6,7 +6,8 @@
 import statistics
 from typing import Any, Dict, List
 
-from llama_stack.apis.scoring import AggregationFunctionType, ScoringResultRow
+from llama_stack.apis.scoring import ScoringResultRow
+from llama_stack.apis.scoring_functions import AggregationFunctionType
 
 
 def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
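
This hunk moves AggregationFunctionType out of llama_stack.apis.scoring and into llama_stack.apis.scoring_functions; aggregate_accuracy keeps the signature shown. A minimal sketch of such an aggregator, assuming each result row stores its score under a "score" key:

import statistics
from typing import Any, Dict, List

# Stand-in for llama_stack's ScoringResultRow: one dict per scored example.
ScoringResultRow = Dict[str, Any]


def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
    # For 0/1 scores, the mean of the per-row scores is the accuracy.
    return {"accuracy": statistics.mean(r["score"] for r in scoring_results)}


print(aggregate_accuracy([{"score": 1.0}, {"score": 0.0}, {"score": 1.0}]))
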

@@ -12,10 +12,18 @@ import threading
 import uuid
 from datetime import datetime
 from functools import wraps
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Optional
 
-from llama_stack.apis.telemetry import *  # noqa: F403
+from llama_stack.apis.telemetry import (
+    LogSeverity,
+    Span,
+    SpanEndPayload,
+    SpanStartPayload,
+    SpanStatus,
+    StructuredLogEvent,
+    Telemetry,
+    UnstructuredLogEvent,
+)
+from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
 
 log = logging.getLogger(__name__)