Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-28 04:02:00 +00:00)

Merge branch 'main' into out-of-token-budget-fix
commit 85ef55391d
61 changed files with 1322 additions and 1598 deletions

@@ -31,7 +31,7 @@ from llama_stack.apis.tools import ToolDef
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

from .openai_responses import (
    OpenAIResponseInputMessage,
    OpenAIResponseInput,
    OpenAIResponseInputTool,
    OpenAIResponseObject,
    OpenAIResponseObjectStream,

@@ -415,6 +415,7 @@ class Agents(Protocol):
        :returns: If stream=False, returns a Turn object.
                  If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk
        """
        ...

    @webmethod(
        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",

@@ -592,7 +593,7 @@ class Agents(Protocol):
    @webmethod(route="/openai/v1/responses", method="POST")
    async def create_openai_response(
        self,
        input: str | list[OpenAIResponseInputMessage],
        input: str | list[OpenAIResponseInput],
        model: str,
        previous_response_id: str | None = None,
        store: bool | None = True,

@@ -606,3 +607,4 @@ class Agents(Protocol):
        :param model: The underlying LLM used for completions.
        :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
        """
        ...

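A rough illustration of the continuation behavior described in the previous_response_id docstring above. This is a hypothetical client-side sketch, not part of this diff; the "agents" handle and the model name are assumptions:

# Hypothetical usage sketch: create one response, then fork a follow-up from it.
first = await agents.create_openai_response(
    input="What is the capital of France?",
    model="meta-llama/Llama-3.1-8B-Instruct",
)
followup = await agents.create_openai_response(
    input="And what is its population?",
    model="meta-llama/Llama-3.1-8B-Instruct",
    previous_response_id=first.id,  # continue from the stored response
)
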
@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated, Literal
from typing import Annotated, Any, Literal

from pydantic import BaseModel, Field

@@ -17,6 +17,28 @@ class OpenAIResponseError(BaseModel):
    message: str


@json_schema_type
class OpenAIResponseInputMessageContentText(BaseModel):
    text: str
    type: Literal["input_text"] = "input_text"


@json_schema_type
class OpenAIResponseInputMessageContentImage(BaseModel):
    detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
    type: Literal["input_image"] = "input_image"
    # TODO: handle file_id
    image_url: str | None = None


# TODO: handle file content types
OpenAIResponseInputMessageContent = Annotated[
    OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
    Field(discriminator="type"),
]
register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")


@json_schema_type
class OpenAIResponseOutputMessageContentOutputText(BaseModel):
    text: str

@@ -31,13 +53,22 @@ register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")


@json_schema_type
class OpenAIResponseOutputMessage(BaseModel):
    id: str
    content: list[OpenAIResponseOutputMessageContent]
    role: Literal["assistant"] = "assistant"
    status: str
class OpenAIResponseMessage(BaseModel):
    """
    Corresponds to the various Message types in the Responses API.
    They are all under one type because the Responses API gives them all
    the same "type" value, and there is no way to tell them apart in certain
    scenarios.
    """

    content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]
    role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
    type: Literal["message"] = "message"

    # The fields below are not used in all scenarios, but are required in others.
    id: str | None = None
    status: str | None = None


@json_schema_type
class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):

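A minimal sketch of what the unified OpenAIResponseMessage buys: the same model can represent a user input message and an assistant output message, differing only in role and content shape. Values here are illustrative, not from this diff:

# Illustrative only: one type covers both directions of the conversation.
user_msg = OpenAIResponseMessage(role="user", content="What is the capital of France?")
assistant_msg = OpenAIResponseMessage(
    role="assistant",
    content=[OpenAIResponseOutputMessageContentOutputText(text="Paris")],
    id="msg_example",      # assigned by the server in real responses
    status="completed",
)
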
@@ -46,8 +77,18 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
    type: Literal["web_search_call"] = "web_search_call"


@json_schema_type
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
    arguments: str
    call_id: str
    name: str
    type: Literal["function_call"] = "function_call"
    id: str
    status: str


OpenAIResponseOutput = Annotated[
    OpenAIResponseOutputMessage | OpenAIResponseOutputMessageWebSearchToolCall,
    OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall,
    Field(discriminator="type"),
]
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")

@@ -90,32 +131,29 @@ register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")


@json_schema_type
class OpenAIResponseInputMessageContentText(BaseModel):
    text: str
    type: Literal["input_text"] = "input_text"
class OpenAIResponseInputFunctionToolCallOutput(BaseModel):
    """
    This represents the output of a function call that gets passed back to the model.
    """

    call_id: str
    output: str
    type: Literal["function_call_output"] = "function_call_output"
    id: str | None = None
    status: str | None = None


@json_schema_type
class OpenAIResponseInputMessageContentImage(BaseModel):
    detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
    type: Literal["input_image"] = "input_image"
    # TODO: handle file_id
    image_url: str | None = None


# TODO: handle file content types
OpenAIResponseInputMessageContent = Annotated[
    OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
    Field(discriminator="type"),
OpenAIResponseInput = Annotated[
    # Responses API allows output messages to be passed in as input
    OpenAIResponseOutputMessageWebSearchToolCall
    | OpenAIResponseOutputMessageFunctionToolCall
    | OpenAIResponseInputFunctionToolCallOutput
    |
    # Fallback to the generic message type as a last resort
    OpenAIResponseMessage,
    Field(union_mode="left_to_right"),
]
register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")


@json_schema_type
class OpenAIResponseInputMessage(BaseModel):
    content: str | list[OpenAIResponseInputMessageContent]
    role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
    type: Literal["message"] | None = "message"
register_schema(OpenAIResponseInput, name="OpenAIResponseInput")


@json_schema_type

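To make the widened input union concrete, here is a hedged sketch of a follow-up request body that feeds a function call's result back to the model, which is what OpenAIResponseInputFunctionToolCallOutput exists for. The tool name, call id, and payloads are invented for illustration:

# Illustrative sketch: echo the model's function_call back together with its output.
followup_input: list[OpenAIResponseInput] = [
    OpenAIResponseOutputMessageFunctionToolCall(
        id="fc_example",
        call_id="call_123",
        name="get_weather",
        arguments='{"city": "Paris"}',
        status="completed",
    ),
    OpenAIResponseInputFunctionToolCallOutput(
        call_id="call_123",
        output='{"temperature_c": 21}',
    ),
]
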
@@ -126,8 +164,35 @@ class OpenAIResponseInputToolWebSearch(BaseModel):
    # TODO: add user_location


@json_schema_type
class OpenAIResponseInputToolFunction(BaseModel):
    type: Literal["function"] = "function"
    name: str
    description: str | None = None
    parameters: dict[str, Any] | None
    strict: bool | None = None


class FileSearchRankingOptions(BaseModel):
    ranker: str | None = None
    score_threshold: float | None = Field(default=0.0, ge=0.0, le=1.0)


@json_schema_type
class OpenAIResponseInputToolFileSearch(BaseModel):
    type: Literal["file_search"] = "file_search"
    vector_store_id: list[str]
    ranking_options: FileSearchRankingOptions | None = None
    # TODO: add filters


OpenAIResponseInputTool = Annotated[
    OpenAIResponseInputToolWebSearch,
    OpenAIResponseInputToolWebSearch | OpenAIResponseInputToolFileSearch | OpenAIResponseInputToolFunction,
    Field(discriminator="type"),
]
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")


class OpenAIResponseInputItemList(BaseModel):
    data: list[OpenAIResponseInput]
    object: Literal["list"] = "list"

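A hedged sketch of how the expanded OpenAIResponseInputTool union might be populated when creating a response. The function schema below is invented for illustration, and the assumption that OpenAIResponseInputToolWebSearch can be constructed with defaults may not hold for every field:

# Illustrative tool list: a built-in web search tool plus a custom function tool.
tools: list[OpenAIResponseInputTool] = [
    OpenAIResponseInputToolWebSearch(),
    OpenAIResponseInputToolFunction(
        name="get_weather",
        description="Look up the current weather for a city",
        parameters={
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    ),
]
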
@@ -95,6 +95,7 @@ class Eval(Protocol):
        :param benchmark_config: The configuration for the benchmark.
        :return: The job that was created to run the evaluation.
        """
        ...

    @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
    async def evaluate_rows(

@@ -112,6 +113,7 @@ class Eval(Protocol):
        :param benchmark_config: The configuration for the benchmark.
        :return: EvaluateResponse object containing generations and scores
        """
        ...

    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
    async def job_status(self, benchmark_id: str, job_id: str) -> Job:

@@ -140,3 +142,4 @@ class Eval(Protocol):
        :param job_id: The ID of the job to get the result of.
        :return: The result of the job.
        """
        ...

@@ -249,6 +249,10 @@ class ServerConfig(BaseModel):
        default=None,
        description="Path to TLS key file for HTTPS",
    )
    tls_cafile: str | None = Field(
        default=None,
        description="Path to TLS CA file for HTTPS with mutual TLS authentication",
    )
    auth: AuthenticationConfig | None = Field(
        default=None,
        description="Authentication configuration for the server",

@@ -99,7 +99,7 @@ class ProviderImpl(Providers):
            try:
                health = await asyncio.wait_for(impl.health(), timeout=timeout)
                return api_name, health
            except asyncio.TimeoutError:
            except (asyncio.TimeoutError, TimeoutError):
                return (
                    api_name,
                    HealthResponse(

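Context for this pattern: since Python 3.11, asyncio.TimeoutError is an alias of the builtin TimeoutError, while on older interpreters they are distinct classes, so catching both keeps the handler correct across versions. A small illustrative check:

import asyncio
import sys

# Prints True on Python >= 3.11 (asyncio.TimeoutError is builtins.TimeoutError),
# False on older versions - which is why the except clause lists both.
print(sys.version_info[:2], asyncio.TimeoutError is TimeoutError)
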
@@ -630,7 +630,7 @@ class InferenceRouter(Inference):
                    continue
                health = await asyncio.wait_for(impl.health(), timeout=timeout)
                health_statuses[provider_id] = health
            except asyncio.TimeoutError:
            except (asyncio.TimeoutError, TimeoutError):
                health_statuses[provider_id] = HealthResponse(
                    status=HealthStatus.ERROR,
                    message=f"Health check timed out after {timeout} seconds",

@@ -9,6 +9,7 @@ import asyncio
import inspect
import json
import os
import ssl
import sys
import traceback
import warnings

@@ -17,6 +18,7 @@ from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Annotated, Any

import rich.pretty
import yaml
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath

@@ -114,7 +116,7 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
        return HTTPException(status_code=400, detail=str(exc))
    elif isinstance(exc, PermissionError):
        return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
    elif isinstance(exc, TimeoutError):
    elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
        return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
    elif isinstance(exc, NotImplementedError):
        return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")

@@ -139,7 +141,7 @@ async def shutdown(app):
            await asyncio.wait_for(impl.shutdown(), timeout=5)
        else:
            logger.warning("No shutdown method for %s", impl_name)
    except asyncio.TimeoutError:
    except (asyncio.TimeoutError, TimeoutError):
        logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
    except (Exception, asyncio.CancelledError) as e:
        logger.exception("Failed to shutdown %s: %s", impl_name, {e})

@@ -186,11 +188,30 @@ async def sse_generator(event_gen_coroutine):
        )


async def log_request_pre_validation(request: Request):
    if request.method in ("POST", "PUT", "PATCH"):
        try:
            body_bytes = await request.body()
            if body_bytes:
                try:
                    parsed_body = json.loads(body_bytes.decode())
                    log_output = rich.pretty.pretty_repr(parsed_body)
                except (json.JSONDecodeError, UnicodeDecodeError):
                    log_output = repr(body_bytes)
                logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
            else:
                logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
        except Exception as e:
            logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")


def create_dynamic_typed_route(func: Any, method: str, route: str):
    async def endpoint(request: Request, **kwargs):
        # Get auth attributes from the request scope
        user_attributes = request.scope.get("user_attributes", {})

        await log_request_pre_validation(request)

        # Use context manager with both provider data and auth attributes
        with request_provider_data_context(request.headers, user_attributes):
            is_streaming = is_streaming_request(func.__name__, request, **kwargs)

@@ -484,7 +505,14 @@ def main(args: argparse.Namespace | None = None):
            "ssl_keyfile": keyfile,
            "ssl_certfile": certfile,
        }
        logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
        if config.server.tls_cafile:
            ssl_config["ssl_ca_certs"] = config.server.tls_cafile
            ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
            logger.info(
                f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
            )
        else:
            logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")

    listen_host = ["::", "0.0.0.0"] if not config.server.disable_ipv6 else "0.0.0.0"
    logger.info(f"Listening on {listen_host}:{port}")

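A rough sketch of the uvicorn SSL settings this branch ends up building when tls_cafile is configured. The paths are placeholders; the keys mirror the ssl_config assembled above:

import ssl

# Placeholder paths, for illustration only.
ssl_config = {
    "ssl_keyfile": "/path/to/server.key",
    "ssl_certfile": "/path/to/server.crt",
    # Added only when config.server.tls_cafile is set: clients must then present
    # a certificate signed by this CA (mutual TLS).
    "ssl_ca_certs": "/path/to/ca.crt",
    "ssl_cert_reqs": ssl.CERT_REQUIRED,
}
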
@@ -20,7 +20,7 @@ from llama_stack.apis.agents import (
    AgentTurnCreateRequest,
    AgentTurnResumeRequest,
    Document,
    OpenAIResponseInputMessage,
    OpenAIResponseInput,
    OpenAIResponseInputTool,
    OpenAIResponseObject,
    Session,

@@ -37,8 +37,8 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.utils.datasetio.pagination import paginate_records
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records

from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig

@@ -311,7 +311,7 @@ class MetaReferenceAgentsImpl(Agents):

    async def create_openai_response(
        self,
        input: str | list[OpenAIResponseInputMessage],
        input: str | list[OpenAIResponseInput],
        model: str,
        previous_response_id: str | None = None,
        store: bool | None = True,

@@ -10,19 +10,26 @@ from collections.abc import AsyncIterator
from typing import cast

from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseInputMessage,
    OpenAIResponseInput,
    OpenAIResponseInputFunctionToolCallOutput,
    OpenAIResponseInputItemList,
    OpenAIResponseInputMessageContent,
    OpenAIResponseInputMessageContentImage,
    OpenAIResponseInputMessageContentText,
    OpenAIResponseInputTool,
    OpenAIResponseInputToolFunction,
    OpenAIResponseMessage,
    OpenAIResponseObject,
    OpenAIResponseObjectStream,
    OpenAIResponseObjectStreamResponseCompleted,
    OpenAIResponseObjectStreamResponseCreated,
    OpenAIResponseOutput,
    OpenAIResponseOutputMessage,
    OpenAIResponseOutputMessageContent,
    OpenAIResponseOutputMessageContentOutputText,
    OpenAIResponseOutputMessageFunctionToolCall,
    OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.inference.inference import (

@@ -32,10 +39,13 @@ from llama_stack.apis.inference.inference import (
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartParam,
    OpenAIChatCompletionContentPartTextParam,
    OpenAIChatCompletionToolCall,
    OpenAIChatCompletionToolCallFunction,
    OpenAIChoice,
    OpenAIDeveloperMessageParam,
    OpenAIImageURL,
    OpenAIMessageParam,
    OpenAISystemMessageParam,
    OpenAIToolMessageParam,
    OpenAIUserMessageParam,
)

@@ -50,31 +60,110 @@ logger = get_logger(name=__name__, category="openai_responses")
OPENAI_RESPONSES_PREFIX = "openai_responses:"


async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> list[OpenAIMessageParam]:
async def _convert_response_content_to_chat_content(
    content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent],
) -> str | list[OpenAIChatCompletionContentPartParam]:
    """
    Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.

    The content schemas of each API look similar, but are not exactly the same.
    """
    if isinstance(content, str):
        return content

    converted_parts = []
    for content_part in content:
        if isinstance(content_part, OpenAIResponseInputMessageContentText):
            converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
        elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
            converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
        elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
            if content_part.image_url:
                image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
                converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
        elif isinstance(content_part, str):
            converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
        else:
            raise ValueError(
                f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
            )
    return converted_parts


async def _convert_response_input_to_chat_messages(
    input: str | list[OpenAIResponseInput],
) -> list[OpenAIMessageParam]:
    """
    Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
    """
    messages: list[OpenAIMessageParam] = []
    for output_message in previous_response.output:
        if isinstance(output_message, OpenAIResponseOutputMessage):
            messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
    if isinstance(input, list):
        for input_item in input:
            if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
                messages.append(
                    OpenAIToolMessageParam(
                        content=input_item.output,
                        tool_call_id=input_item.call_id,
                    )
                )
            elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
                tool_call = OpenAIChatCompletionToolCall(
                    index=0,
                    id=input_item.call_id,
                    function=OpenAIChatCompletionToolCallFunction(
                        name=input_item.name,
                        arguments=input_item.arguments,
                    ),
                )
                messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
            else:
                content = await _convert_response_content_to_chat_content(input_item.content)
                message_type = await _get_message_type_by_role(input_item.role)
                if message_type is None:
                    raise ValueError(
                        f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
                    )
                messages.append(message_type(content=content))
    else:
        messages.append(OpenAIUserMessageParam(content=input))
    return messages


async def _openai_choices_to_output_messages(choices: list[OpenAIChoice]) -> list[OpenAIResponseOutputMessage]:
    output_messages = []
    for choice in choices:
        output_content = ""
        if isinstance(choice.message.content, str):
            output_content = choice.message.content
        elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
            output_content = choice.message.content.text
        # TODO: handle image content
        output_messages.append(
            OpenAIResponseOutputMessage(
                id=f"msg_{uuid.uuid4()}",
                content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
                status="completed",
            )
async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
    """
    Convert an OpenAI Chat Completion choice into an OpenAI Response output message.
    """
    output_content = ""
    if isinstance(choice.message.content, str):
        output_content = choice.message.content
    elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
        output_content = choice.message.content.text
    else:
        raise ValueError(
            f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
        )
    return output_messages

    return OpenAIResponseMessage(
        id=f"msg_{uuid.uuid4()}",
        content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
        status="completed",
        role="assistant",
    )


async def _get_message_type_by_role(role: str):
    role_to_type = {
        "user": OpenAIUserMessageParam,
        "system": OpenAISystemMessageParam,
        "assistant": OpenAIAssistantMessageParam,
        "developer": OpenAIDeveloperMessageParam,
    }
    return role_to_type.get(role)


class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
    input_items: OpenAIResponseInputItemList
    response: OpenAIResponseObject


class OpenAIResponsesImpl:

@@ -90,19 +179,45 @@ class OpenAIResponsesImpl:
        self.tool_groups_api = tool_groups_api
        self.tool_runtime_api = tool_runtime_api

    async def get_openai_response(
        self,
        id: str,
    ) -> OpenAIResponseObject:
    async def _get_previous_response_with_input(self, id: str) -> OpenAIResponsePreviousResponseWithInputItems:
        key = f"{OPENAI_RESPONSES_PREFIX}{id}"
        response_json = await self.persistence_store.get(key=key)
        if response_json is None:
            raise ValueError(f"OpenAI response with id '{id}' not found")
        return OpenAIResponseObject.model_validate_json(response_json)
        return OpenAIResponsePreviousResponseWithInputItems.model_validate_json(response_json)

    async def _prepend_previous_response(
        self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
    ):
        if previous_response_id:
            previous_response_with_input = await self._get_previous_response_with_input(previous_response_id)

            # previous response input items
            new_input_items = previous_response_with_input.input_items.data

            # previous response output items
            new_input_items.extend(previous_response_with_input.response.output)

            # new input items from the current request
            if isinstance(input, str):
                new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
            else:
                new_input_items.extend(input)

            input = new_input_items

        return input

    async def get_openai_response(
        self,
        id: str,
    ) -> OpenAIResponseObject:
        response_with_input = await self._get_previous_response_with_input(id)
        return response_with_input.response

    async def create_openai_response(
        self,
        input: str | list[OpenAIResponseInputMessage],
        input: str | list[OpenAIResponseInput],
        model: str,
        previous_response_id: str | None = None,
        store: bool | None = True,

@@ -112,31 +227,8 @@ class OpenAIResponsesImpl:
    ):
        stream = False if stream is None else stream

        messages: list[OpenAIMessageParam] = []
        if previous_response_id:
            previous_response = await self.get_openai_response(previous_response_id)
            messages.extend(await _previous_response_to_messages(previous_response))
        # TODO: refactor this user_content parsing out into a separate method
        user_content: str | list[OpenAIChatCompletionContentPartParam] = ""
        if isinstance(input, list):
            user_content = []
            for user_input in input:
                if isinstance(user_input.content, list):
                    for user_input_content in user_input.content:
                        if isinstance(user_input_content, OpenAIResponseInputMessageContentText):
                            user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input_content.text))
                        elif isinstance(user_input_content, OpenAIResponseInputMessageContentImage):
                            if user_input_content.image_url:
                                image_url = OpenAIImageURL(
                                    url=user_input_content.image_url, detail=user_input_content.detail
                                )
                                user_content.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
                else:
                    user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input.content))
        else:
            user_content = input
        messages.append(OpenAIUserMessageParam(content=user_content))

        input = await self._prepend_previous_response(input, previous_response_id)
        messages = await _convert_response_input_to_chat_messages(input)
        chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
        chat_response = await self.inference_api.openai_chat_completion(
            model=model,

@@ -150,6 +242,7 @@ class OpenAIResponsesImpl:
            # TODO: refactor this into a separate method that handles streaming
            chat_response_id = ""
            chat_response_content = []
            chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
            # TODO: these chunk_ fields are hacky and only take the last chunk into account
            chunk_created = 0
            chunk_model = ""

@@ -163,7 +256,26 @@ class OpenAIResponsesImpl:
                    chat_response_content.append(chunk_choice.delta.content or "")
                    if chunk_choice.finish_reason:
                        chunk_finish_reason = chunk_choice.finish_reason
            assistant_message = OpenAIAssistantMessageParam(content="".join(chat_response_content))

                    # Aggregate tool call arguments across chunks, using their index as the aggregation key
                    if chunk_choice.delta.tool_calls:
                        for tool_call in chunk_choice.delta.tool_calls:
                            response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
                            if response_tool_call:
                                response_tool_call.function.arguments += tool_call.function.arguments
                            else:
                                response_tool_call = OpenAIChatCompletionToolCall(**tool_call.model_dump())
                            chat_response_tool_calls[tool_call.index] = response_tool_call

            # Convert the dict of tool calls by index to a list of tool calls to pass back in our response
            if chat_response_tool_calls:
                tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
            else:
                tool_calls = None
            assistant_message = OpenAIAssistantMessageParam(
                content="".join(chat_response_content),
                tool_calls=tool_calls,
            )
            chat_response = OpenAIChatCompletion(
                id=chat_response_id,
                choices=[

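To illustrate the aggregation above: streamed chat-completion chunks deliver a tool call's arguments as string fragments that share the same index, so concatenating fragments per index rebuilds the complete JSON arguments. A standalone sketch with made-up deltas:

# Illustrative deltas for tool call index 0, with the arguments split across chunks.
deltas = [
    {"index": 0, "id": "call_123", "function": {"name": "get_weather", "arguments": '{"city": '}},
    {"index": 0, "function": {"arguments": '"Paris"}'}},
]

rebuilt: dict[int, str] = {}
for delta in deltas:
    rebuilt[delta["index"]] = rebuilt.get(delta["index"], "") + delta["function"]["arguments"]

assert rebuilt[0] == '{"city": "Paris"}'
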
@@ -181,12 +293,26 @@ class OpenAIResponsesImpl:
            chat_response = OpenAIChatCompletion(**chat_response.model_dump())

        output_messages: list[OpenAIResponseOutput] = []
        if chat_response.choices[0].message.tool_calls:
            output_messages.extend(
                await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature)
            )
        else:
            output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))
        for choice in chat_response.choices:
            if choice.message.tool_calls and tools:
                # Assume if the first tool is a function, all tools are functions
                if isinstance(tools[0], OpenAIResponseInputToolFunction):
                    for tool_call in choice.message.tool_calls:
                        output_messages.append(
                            OpenAIResponseOutputMessageFunctionToolCall(
                                arguments=tool_call.function.arguments or "",
                                call_id=tool_call.id,
                                name=tool_call.function.name or "",
                                id=f"fc_{uuid.uuid4()}",
                                status="completed",
                            )
                        )
                else:
                    output_messages.extend(
                        await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature)
                    )
            else:
                output_messages.append(await _convert_chat_choice_to_response_message(choice))
        response = OpenAIResponseObject(
            created_at=chat_response.created,
            id=f"resp-{uuid.uuid4()}",

@@ -195,13 +321,43 @@ class OpenAIResponsesImpl:
            status="completed",
            output=output_messages,
        )
        logger.debug(f"OpenAI Responses response: {response}")

        if store:
            # Store in kvstore

            new_input_id = f"msg_{uuid.uuid4()}"
            if isinstance(input, str):
                # synthesize a message from the input string
                input_content = OpenAIResponseInputMessageContentText(text=input)
                input_content_item = OpenAIResponseMessage(
                    role="user",
                    content=[input_content],
                    id=new_input_id,
                )
                input_items_data = [input_content_item]
            else:
                # we already have a list of messages
                input_items_data = []
                for input_item in input:
                    if isinstance(input_item, OpenAIResponseMessage):
                        # These may or may not already have an id, so dump to dict, check for id, and add if missing
                        input_item_dict = input_item.model_dump()
                        if "id" not in input_item_dict:
                            input_item_dict["id"] = new_input_id
                        input_items_data.append(OpenAIResponseMessage(**input_item_dict))
                    else:
                        input_items_data.append(input_item)

            input_items = OpenAIResponseInputItemList(data=input_items_data)
            prev_response = OpenAIResponsePreviousResponseWithInputItems(
                input_items=input_items,
                response=response,
            )
            key = f"{OPENAI_RESPONSES_PREFIX}{response.id}"
            await self.persistence_store.set(
                key=key,
                value=response.model_dump_json(),
                value=prev_response.model_dump_json(),
            )

        if stream:

@@ -221,7 +377,9 @@ class OpenAIResponsesImpl:
        chat_tools: list[ChatCompletionToolParam] = []
        for input_tool in tools:
            # TODO: Handle other tool types
            if input_tool.type == "web_search":
            if input_tool.type == "function":
                chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
            elif input_tool.type == "web_search":
                tool_name = "web_search"
                tool = await self.tool_groups_api.get_tool(tool_name)
                tool_def = ToolDefinition(

@@ -247,12 +405,11 @@ class OpenAIResponsesImpl:
        self,
        model_id: str,
        stream: bool,
        chat_response: OpenAIChatCompletion,
        choice: OpenAIChoice,
        messages: list[OpenAIMessageParam],
        temperature: float,
    ) -> list[OpenAIResponseOutput]:
        output_messages: list[OpenAIResponseOutput] = []
        choice = chat_response.choices[0]

        # If the choice is not an assistant message, we don't need to execute any tools
        if not isinstance(choice.message, OpenAIAssistantMessageParam):

@@ -262,6 +419,9 @@ class OpenAIResponsesImpl:
        if not choice.message.tool_calls:
            return output_messages

        # Copy the messages list to avoid mutating the original list
        messages = messages.copy()

        # Add the assistant message with tool_calls response to the messages list
        messages.append(choice.message)

@@ -307,7 +467,9 @@ class OpenAIResponsesImpl:
        )
        # type cast to appease mypy
        tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
        tool_final_outputs = await _openai_choices_to_output_messages(tool_results_chat_response.choices)
        tool_final_outputs = [
            await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
        ]
        # TODO: Wire in annotations with URLs, titles, etc to these output messages
        output_messages.extend(tool_final_outputs)
        return output_messages

@@ -11,9 +11,9 @@ from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.pagination import paginate_records
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records

from .config import LocalFSDatasetIOConfig

@@ -105,7 +105,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
        query_config: RAGQueryConfig | None = None,
    ) -> RAGQueryResult:
        if not vector_db_ids:
            return RAGQueryResult(content=None)
            raise ValueError(
                "No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
            )

        query_config = query_config or RAGQueryConfig()
        query = await generate_rag_query(

@@ -12,8 +12,8 @@ from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.pagination import paginate_records
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records

from .config import HuggingfaceDatasetIOConfig

@@ -61,6 +61,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    get_sampling_options,
    prepare_openai_completion_params,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,

@@ -395,29 +396,25 @@ class OllamaInferenceAdapter(
            raise ValueError("Ollama does not support non-string prompts for completion")

        model_obj = await self._get_model(model)
        params = {
            k: v
            for k, v in {
                "model": model_obj.provider_resource_id,
                "prompt": prompt,
                "best_of": best_of,
                "echo": echo,
                "frequency_penalty": frequency_penalty,
                "logit_bias": logit_bias,
                "logprobs": logprobs,
                "max_tokens": max_tokens,
                "n": n,
                "presence_penalty": presence_penalty,
                "seed": seed,
                "stop": stop,
                "stream": stream,
                "stream_options": stream_options,
                "temperature": temperature,
                "top_p": top_p,
                "user": user,
            }.items()
            if v is not None
        }
        params = await prepare_openai_completion_params(
            model=model_obj.provider_resource_id,
            prompt=prompt,
            best_of=best_of,
            echo=echo,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_tokens=max_tokens,
            n=n,
            presence_penalty=presence_penalty,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            top_p=top_p,
            user=user,
        )
        return await self.openai_client.completions.create(**params)  # type: ignore

    async def openai_chat_completion(

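The helper replaces the inline dict comprehension above. A minimal sketch of the behavior the call sites rely on - dropping None-valued keyword arguments - assuming the real prepare_openai_completion_params in llama_stack.providers.utils.inference.openai_compat may do more than this:

# Simplified sketch of the None-dropping behavior (not the actual implementation).
async def prepare_openai_completion_params(**params):
    return {k: v for k, v in params.items() if v is not None}
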
@@ -447,35 +444,31 @@ class OllamaInferenceAdapter(
        user: str | None = None,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        model_obj = await self._get_model(model)
        params = {
            k: v
            for k, v in {
                "model": model_obj.provider_resource_id,
                "messages": messages,
                "frequency_penalty": frequency_penalty,
                "function_call": function_call,
                "functions": functions,
                "logit_bias": logit_bias,
                "logprobs": logprobs,
                "max_completion_tokens": max_completion_tokens,
                "max_tokens": max_tokens,
                "n": n,
                "parallel_tool_calls": parallel_tool_calls,
                "presence_penalty": presence_penalty,
                "response_format": response_format,
                "seed": seed,
                "stop": stop,
                "stream": stream,
                "stream_options": stream_options,
                "temperature": temperature,
                "tool_choice": tool_choice,
                "tools": tools,
                "top_logprobs": top_logprobs,
                "top_p": top_p,
                "user": user,
            }.items()
            if v is not None
        }
        params = await prepare_openai_completion_params(
            model=model_obj.provider_resource_id,
            messages=messages,
            frequency_penalty=frequency_penalty,
            function_call=function_call,
            functions=functions,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_completion_tokens=max_completion_tokens,
            max_tokens=max_tokens,
            n=n,
            parallel_tool_calls=parallel_tool_calls,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            tool_choice=tool_choice,
            tools=tools,
            top_logprobs=top_logprobs,
            top_p=top_p,
            user=user,
        )
        return await self.openai_client.chat.completions.create(**params)  # type: ignore

    async def batch_completion(

@@ -26,8 +26,7 @@ from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig

log = logging.getLogger(__name__)


ChromaClientType = chromadb.AsyncHttpClient | chromadb.PersistentClient
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI


# this is a helper to allow us to use async and non-async chroma clients interchangeably

@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -1,55 +0,0 @@
inference:
  tests:
  - inference/test_vision_inference.py::test_vision_chat_completion_streaming
  - inference/test_vision_inference.py::test_vision_chat_completion_non_streaming
  - inference/test_text_inference.py::test_structured_output
  - inference/test_text_inference.py::test_chat_completion_streaming
  - inference/test_text_inference.py::test_chat_completion_non_streaming
  - inference/test_text_inference.py::test_chat_completion_with_tool_calling
  - inference/test_text_inference.py::test_chat_completion_with_tool_calling_streaming

  scenarios:
  - provider_fixtures:
      inference: ollama
  - fixture_combo_id: fireworks
  - provider_fixtures:
      inference: together
  # - inference: tgi
  # - inference: vllm_remote

  inference_models:
  - meta-llama/Llama-3.1-8B-Instruct
  - meta-llama/Llama-3.2-11B-Vision-Instruct


agents:
  tests:
  - agents/test_agents.py::test_agent_turns_with_safety
  - agents/test_agents.py::test_rag_agent

  scenarios:
  - fixture_combo_id: ollama
  - fixture_combo_id: together
  - fixture_combo_id: fireworks

  inference_models:
  - meta-llama/Llama-3.2-1B-Instruct

  safety_shield: meta-llama/Llama-Guard-3-1B


memory:
  tests:
  - memory/test_memory.py::test_query_documents

  scenarios:
  - fixture_combo_id: ollama
  - provider_fixtures:
      inference: sentence_transformers
      memory: faiss
  - fixture_combo_id: chroma

  inference_models:
  - meta-llama/Llama-3.2-1B-Instruct

  embedding_model: all-MiniLM-L6-v2

@@ -1,296 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from collections import defaultdict
from pathlib import Path
from typing import Any

import pytest
import yaml
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from termcolor import colored

from llama_stack.distribution.datatypes import Provider
from llama_stack.providers.datatypes import RemoteProviderConfig

from .env import get_env_or_fail
from .report import Report


class ProviderFixture(BaseModel):
    providers: list[Provider]
    provider_data: dict[str, Any] | None = None


class TestScenario(BaseModel):
    # provider fixtures can be either a mark or a dictionary of api -> providers
    provider_fixtures: dict[str, str] = Field(default_factory=dict)
    fixture_combo_id: str | None = None


class APITestConfig(BaseModel):
    scenarios: list[TestScenario] = Field(default_factory=list)
    inference_models: list[str] = Field(default_factory=list)

    # test name format should be <relative_path.py>::<test_name>
    tests: list[str] = Field(default_factory=list)


class MemoryApiTestConfig(APITestConfig):
    embedding_model: str | None = Field(default_factory=None)


class AgentsApiTestConfig(APITestConfig):
    safety_shield: str | None = Field(default_factory=None)


class TestConfig(BaseModel):
    inference: APITestConfig | None = None
    agents: AgentsApiTestConfig | None = None
    memory: MemoryApiTestConfig | None = None


def get_test_config_from_config_file(metafunc_config):
    config_file = metafunc_config.getoption("--config")
    if config_file is None:
        return None

    config_file_path = Path(__file__).parent / config_file
    if not config_file_path.exists():
        raise ValueError(
            f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
        )
    with open(config_file_path) as config_file:
        config = yaml.safe_load(config_file)
        return TestConfig(**config)


def get_test_config_for_api(metafunc_config, api):
    test_config = get_test_config_from_config_file(metafunc_config)
    if test_config is None:
        return None
    return getattr(test_config, api)


def get_provider_fixture_overrides_from_test_config(metafunc_config, api, default_provider_fixture_combinations):
    api_config = get_test_config_for_api(metafunc_config, api)
    if api_config is None:
        return None

    fixture_combo_ids = set()
    custom_provider_fixture_combos = []
    for scenario in api_config.scenarios:
        if scenario.fixture_combo_id:
            fixture_combo_ids.add(scenario.fixture_combo_id)
        else:
            custom_provider_fixture_combos.append(
                pytest.param(
                    scenario.provider_fixtures,
                    id=scenario.provider_fixtures.get("inference") or "",
                )
            )

    if len(fixture_combo_ids) > 0:
        for default_fixture in default_provider_fixture_combinations:
            if default_fixture.id in fixture_combo_ids:
                custom_provider_fixture_combos.append(default_fixture)
    return custom_provider_fixture_combos


def remote_stack_fixture() -> ProviderFixture:
    if url := os.getenv("REMOTE_STACK_URL", None):
        config = RemoteProviderConfig.from_url(url)
    else:
        config = RemoteProviderConfig(
            host=get_env_or_fail("REMOTE_STACK_HOST"),
            port=int(get_env_or_fail("REMOTE_STACK_PORT")),
        )
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="test::remote",
                provider_type="test::remote",
                config=config.model_dump(),
            )
        ],
    )


def pytest_configure(config):
    config.option.tbstyle = "short"
    config.option.disable_warnings = True

    """Load environment variables at start of test run"""
    # Load from .env file if it exists
    env_file = Path(__file__).parent / ".env"
    if env_file.exists():
        load_dotenv(env_file)

    # Load any environment variables passed via --env
    env_vars = config.getoption("--env") or []
    for env_var in env_vars:
        key, value = env_var.split("=", 1)
        os.environ[key] = value

    if config.getoption("--output") is not None:
        config.pluginmanager.register(Report(config.getoption("--output")))


def pytest_addoption(parser):
    parser.addoption(
        "--providers",
        default="",
        help=(
            "Provider configuration in format: api1=provider1,api2=provider2. "
            "Example: --providers inference=ollama,safety=meta-reference"
        ),
    )
    parser.addoption(
        "--config",
        action="store",
        help="Set test config file (supported format: YAML), e.g. --config=test_config.yml",
    )
    parser.addoption(
        "--output",
        action="store",
        help="Set output file for test report, e.g. --output=pytest_report.md",
    )
    """Add custom command line options"""
    parser.addoption("--env", action="append", help="Set environment variables, e.g. --env KEY=value")
    parser.addoption(
        "--inference-model",
        action="store",
        default="meta-llama/Llama-3.2-3B-Instruct",
        help="Specify the inference model to use for testing",
    )
    parser.addoption(
        "--safety-shield",
        action="store",
        default="meta-llama/Llama-Guard-3-1B",
        help="Specify the safety shield to use for testing",
    )
    parser.addoption(
        "--embedding-model",
        action="store",
        default=None,
        help="Specify the embedding model to use for testing",
    )
    parser.addoption(
        "--judge-model",
        action="store",
        default="meta-llama/Llama-3.1-8B-Instruct",
        help="Specify the judge model to use for testing",
    )


def make_provider_id(providers: dict[str, str]) -> str:
    return ":".join(f"{api}={provider}" for api, provider in sorted(providers.items()))


def get_provider_marks(providers: dict[str, str]) -> list[Any]:
    marks = []
    for provider in providers.values():
        marks.append(getattr(pytest.mark, provider))
    return marks


def get_provider_fixture_overrides(config, available_fixtures: dict[str, list[str]]) -> list[pytest.param] | None:
    provider_str = config.getoption("--providers")
    if not provider_str:
        return None

    fixture_dict = parse_fixture_string(provider_str, available_fixtures)
    return [
        pytest.param(
            fixture_dict,
            id=make_provider_id(fixture_dict),
            marks=get_provider_marks(fixture_dict),
        )
    ]


def parse_fixture_string(provider_str: str, available_fixtures: dict[str, list[str]]) -> dict[str, str]:
    """Parse provider string of format 'api1=provider1,api2=provider2'"""
    if not provider_str:
        return {}

    fixtures = {}
    pairs = provider_str.split(",")
    for pair in pairs:
        if "=" not in pair:
            raise ValueError(f"Invalid provider specification: {pair}. Expected format: api=provider")
        api, fixture = pair.split("=")
        if api not in available_fixtures:
            raise ValueError(f"Unknown API: {api}. Available APIs: {list(available_fixtures.keys())}")
        if fixture not in available_fixtures[api]:
            raise ValueError(
                f"Unknown provider '{fixture}' for API '{api}'. Available providers: {list(available_fixtures[api])}"
            )
        fixtures[api] = fixture

    # Check that all provided APIs are supported
    for api in available_fixtures.keys():
        if api not in fixtures:
            raise ValueError(
                f"Missing provider fixture for API '{api}'. Available providers: {list(available_fixtures[api])}"
            )
    return fixtures


def pytest_itemcollected(item):
    # Get all markers as a list
    filtered = ("asyncio", "parametrize")
    marks = [mark.name for mark in item.iter_markers() if mark.name not in filtered]
    if marks:
        marks = colored(",".join(marks), "yellow")
        item.name = f"{item.name}[{marks}]"


def pytest_collection_modifyitems(session, config, items):
    test_config = get_test_config_from_config_file(config)
    if test_config is None:
        return

    required_tests = defaultdict(set)
    for api_test_config in [
        test_config.inference,
        test_config.memory,
        test_config.agents,
    ]:
        if api_test_config is None:
            continue
        for test in api_test_config.tests:
            arr = test.split("::")
            if len(arr) != 2:
                raise ValueError(f"Invalid format for test name {test}")
            test_path, func_name = arr
            required_tests[Path(__file__).parent / test_path].add(func_name)

    new_items, deselected_items = [], []
    for item in items:
        func_name = getattr(item, "originalname", item.name)
        if func_name in required_tests[item.fspath]:
            new_items.append(item)
            continue
        deselected_items.append(item)

    items[:] = new_items
    config.hook.pytest_deselected(items=deselected_items)


pytest_plugins = [
    "llama_stack.providers.tests.inference.fixtures",
    "llama_stack.providers.tests.safety.fixtures",
    "llama_stack.providers.tests.vector_io.fixtures",
    "llama_stack.providers.tests.agents.fixtures",
    "llama_stack.providers.tests.datasetio.fixtures",
    "llama_stack.providers.tests.scoring.fixtures",
    "llama_stack.providers.tests.eval.fixtures",
    "llama_stack.providers.tests.post_training.fixtures",
    "llama_stack.providers.tests.tools.fixtures",
]

@ -1,176 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pytest import ExitCode
|
||||
from pytest_html.basereport import _process_outcome
|
||||
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
|
||||
INFERENCE_APIS = ["chat_completion"]
|
||||
FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]
|
||||
SUPPORTED_MODELS = {
|
||||
"ollama": {
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
        CoreModelId.llama3_2_1b_instruct.value,
        CoreModelId.llama3_2_3b_instruct.value,
        CoreModelId.llama3_2_3b_instruct.value,
        CoreModelId.llama3_2_11b_vision_instruct.value,
        CoreModelId.llama3_2_11b_vision_instruct.value,
        CoreModelId.llama3_2_90b_vision_instruct.value,
        CoreModelId.llama3_2_90b_vision_instruct.value,
        CoreModelId.llama3_3_70b_instruct.value,
        CoreModelId.llama_guard_3_8b.value,
        CoreModelId.llama_guard_3_1b.value,
    },
    "fireworks": {
        CoreModelId.llama3_1_8b_instruct.value,
        CoreModelId.llama3_1_70b_instruct.value,
        CoreModelId.llama3_1_405b_instruct.value,
        CoreModelId.llama3_2_1b_instruct.value,
        CoreModelId.llama3_2_3b_instruct.value,
        CoreModelId.llama3_2_11b_vision_instruct.value,
        CoreModelId.llama3_2_90b_vision_instruct.value,
        CoreModelId.llama3_3_70b_instruct.value,
        CoreModelId.llama_guard_3_8b.value,
        CoreModelId.llama_guard_3_11b_vision.value,
    },
    "together": {
        CoreModelId.llama3_1_8b_instruct.value,
        CoreModelId.llama3_1_70b_instruct.value,
        CoreModelId.llama3_1_405b_instruct.value,
        CoreModelId.llama3_2_3b_instruct.value,
        CoreModelId.llama3_2_11b_vision_instruct.value,
        CoreModelId.llama3_2_90b_vision_instruct.value,
        CoreModelId.llama3_3_70b_instruct.value,
        CoreModelId.llama_guard_3_8b.value,
        CoreModelId.llama_guard_3_11b_vision.value,
    },
}


class Report:
    def __init__(self, output_path):
        valid_file_format = (
            output_path.split(".")[1] in ["md", "markdown"] if len(output_path.split(".")) == 2 else False
        )
        if not valid_file_format:
            raise ValueError(f"Invalid output file {output_path}. Markdown file is required")
        self.output_path = output_path
        self.test_data = defaultdict(dict)
        self.inference_tests = defaultdict(dict)

    @pytest.hookimpl
    def pytest_runtest_logreport(self, report):
        # This hook is called in several phases, including setup, call and teardown
        # The test is considered failed / error if any of the outcomes is not "Passed"
        outcome = _process_outcome(report)
        data = {
            "outcome": report.outcome,
            "longrepr": report.longrepr,
            "name": report.nodeid,
        }
        if report.nodeid not in self.test_data:
            self.test_data[report.nodeid] = data
        elif self.test_data[report.nodeid] != outcome and outcome != "Passed":
            self.test_data[report.nodeid] = data

    @pytest.hookimpl
    def pytest_sessionfinish(self, session, exitstatus):
        if exitstatus <= ExitCode.INTERRUPTED:
            return
        report = []
        report.append("# Llama Stack Integration Test Results Report")
        report.append("\n## Summary")
        report.append("\n## Supported Models: ")

        header = "| Model Descriptor |"
        dividor = "|:---|"
        for k in SUPPORTED_MODELS.keys():
            header += f"{k} |"
            dividor += ":---:|"

        report.append(header)
        report.append(dividor)

        rows = []
        for model in all_registered_models():
            if "Instruct" not in model.core_model_id.value and "Guard" not in model.core_model_id.value:
                continue
            row = f"| {model.core_model_id.value} |"
            for k in SUPPORTED_MODELS.keys():
                if model.core_model_id.value in SUPPORTED_MODELS[k]:
                    row += " ✅ |"
                else:
                    row += " ❌ |"
            rows.append(row)
        report.extend(rows)

        report.append("\n### Tests:")

        for provider in SUPPORTED_MODELS.keys():
            if provider not in self.inference_tests:
                continue
            report.append(f"\n #### {provider}")
            test_table = [
                "| Area | Model | API | Functionality Test | Status |",
                "|:-----|:-----|:-----|:-----|:-----|",
            ]
            for api in INFERENCE_APIS:
                tests = self.inference_tests[provider][api]
                for test_nodeid in tests:
                    row = "|{area} | {model} | {api} | {test} | {result} ".format(
                        area="Text" if "text" in test_nodeid else "Vision",
                        model=("Llama-3.1-8B-Instruct" if "text" in test_nodeid else "Llama3.2-11B-Vision-Instruct"),
                        api=f"/{api}",
                        test=self.get_simple_function_name(test_nodeid),
                        result=("✅" if self.test_data[test_nodeid]["outcome"] == "passed" else "❌"),
                    )
                    test_table += [row]
            report.extend(test_table)
            report.append("\n")

        output_file = Path(self.output_path)
        output_file.write_text("\n".join(report))
        print(f"\n Report generated: {output_file.absolute()}")

    @pytest.hookimpl(trylast=True)
    def pytest_collection_modifyitems(self, session, config, items):
        for item in items:
            inference = item.callspec.params.get("inference_stack")
            if "inference" in item.nodeid:
                func_name = getattr(item, "originalname", item.name)
                for api in INFERENCE_APIS:
                    if api in func_name:
                        api_tests = self.inference_tests[inference].get(api, set())
                        api_tests.add(item.nodeid)
                        self.inference_tests[inference][api] = api_tests

    def get_simple_function_name(self, nodeid):
        """Extract the function name from a nodeid.

        Examples:
        - 'tests/test_math.py::test_addition' -> 'test_addition'
        - 'tests/test_math.py::TestClass::test_method' -> 'test_method'
        """
        parts = nodeid.split("::")
        func_name = nodeid  # Fallback to full nodeid if pattern doesn't match
        if len(parts) == 2:  # Simple function
            func_name = parts[1]
        elif len(parts) == 3:  # Class method
            func_name = parts[2]
        return func_name.split("[")[0]

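Note: the hook methods above only run if the Report object is registered as a pytest plugin. The sketch below is a hypothetical illustration of that wiring, not the repository's actual setup; MiniReport, the "mini-report" plugin name, and the conftest.py placement are all assumed for the example.

# conftest.py (illustrative sketch only; names are hypothetical)
import pytest


class MiniReport:
    """Keeps one outcome per test, preferring any non-passing phase."""

    def __init__(self):
        self.outcomes = {}

    @pytest.hookimpl
    def pytest_runtest_logreport(self, report):
        # Called once per phase (setup, call, teardown); remember the worst outcome seen.
        if report.nodeid not in self.outcomes or report.outcome != "passed":
            self.outcomes[report.nodeid] = report.outcome

    @pytest.hookimpl
    def pytest_sessionfinish(self, session, exitstatus):
        # Print a one-line summary per test once the whole session is done.
        for nodeid, outcome in sorted(self.outcomes.items()):
            print(f"{outcome.upper():8} {nodeid}")


def pytest_configure(config):
    # Registering the instance is what makes pytest invoke its hookimpl methods.
    config.pluginmanager.register(MiniReport(), name="mini-report")
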
@@ -1,43 +0,0 @@
# Report for cerebras distribution

## Supported Models
| Model Descriptor | cerebras |
|:---|:---|
| meta-llama/Llama-3-8B-Instruct | ❌ |
| meta-llama/Llama-3-70B-Instruct | ❌ |
| meta-llama/Llama-3.1-8B-Instruct | ✅ |
| meta-llama/Llama-3.1-70B-Instruct | ❌ |
| meta-llama/Llama-3.1-405B-Instruct-FP8 | ❌ |
| meta-llama/Llama-3.2-1B-Instruct | ❌ |
| meta-llama/Llama-3.2-3B-Instruct | ❌ |
| meta-llama/Llama-3.2-11B-Vision-Instruct | ❌ |
| meta-llama/Llama-3.2-90B-Vision-Instruct | ❌ |
| meta-llama/Llama-3.3-70B-Instruct | ✅ |
| meta-llama/Llama-Guard-3-11B-Vision | ❌ |
| meta-llama/Llama-Guard-3-1B | ❌ |
| meta-llama/Llama-Guard-3-8B | ❌ |
| meta-llama/Llama-Guard-2-8B | ❌ |

## Inference
| Model | API | Capability | Test | Status |
|:----- |:-----|:-----|:-----|:-----|
| Llama-3.1-8B-Instruct | /chat_completion | streaming | test_text_chat_completion_streaming | ✅ |
| Llama-3.2-11B-Vision-Instruct | /chat_completion | streaming | test_image_chat_completion_streaming | ❌ |
| Llama-3.2-11B-Vision-Instruct | /chat_completion | non_streaming | test_image_chat_completion_non_streaming | ❌ |
| Llama-3.1-8B-Instruct | /chat_completion | non_streaming | test_text_chat_completion_non_streaming | ✅ |
| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_streaming | ✅ |
| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_non_streaming | ✅ |
| Llama-3.1-8B-Instruct | /completion | streaming | test_text_completion_streaming | ✅ |
| Llama-3.1-8B-Instruct | /completion | non_streaming | test_text_completion_non_streaming | ✅ |
| Llama-3.1-8B-Instruct | /completion | structured_output | test_text_completion_structured_output | ❌ |

## Vector IO
| API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|
| /retrieve | | test_vector_db_retrieve | ✅ |

## Agents
| API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|
| /create_agent_turn | rag | test_rag_agent | ✅ |
| /create_agent_turn | custom_tool | test_custom_tool | ❌ |

@@ -833,6 +833,8 @@
    "tqdm",
    "transformers",
    "tree_sitter",
    "uvicorn"
    "uvicorn",
    "sentence-transformers --no-deps",
    "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
  ]
}

@@ -1,45 +0,0 @@
# Report for fireworks distribution

## Supported Models
| Model Descriptor | fireworks |
|:---|:---|
| Llama-3-8B-Instruct | ❌ |
| Llama-3-70B-Instruct | ❌ |
| Llama3.1-8B-Instruct | ✅ |
| Llama3.1-70B-Instruct | ✅ |
| Llama3.1-405B-Instruct | ✅ |
| Llama3.2-1B-Instruct | ✅ |
| Llama3.2-3B-Instruct | ✅ |
| Llama3.2-11B-Vision-Instruct | ✅ |
| Llama3.2-90B-Vision-Instruct | ✅ |
| Llama3.3-70B-Instruct | ✅ |
| Llama-Guard-3-11B-Vision | ✅ |
| Llama-Guard-3-1B | ❌ |
| Llama-Guard-3-8B | ✅ |
| Llama-Guard-2-8B | ❌ |

## Inference
| Model | API | Capability | Test | Status |
|:----- |:-----|:-----|:-----|:-----|
| Llama-3.1-8B-Instruct | /chat_completion | streaming | test_text_chat_completion_streaming | ✅ |
| Llama-3.2-11B-Vision-Instruct | /chat_completion | streaming | test_image_chat_completion_streaming | ✅ |
| Llama-3.2-11B-Vision-Instruct | /chat_completion | non_streaming | test_image_chat_completion_non_streaming | ✅ |
| Llama-3.1-8B-Instruct | /chat_completion | non_streaming | test_text_chat_completion_non_streaming | ✅ |
| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_streaming | ✅ |
| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_non_streaming | ✅ |
| Llama-3.2-11B-Vision-Instruct | /chat_completion | log_probs | test_completion_log_probs_non_streaming | ✅ |
| Llama-3.2-11B-Vision-Instruct | /chat_completion | log_probs | test_completion_log_probs_streaming | ✅ |
| Llama-3.1-8B-Instruct | /completion | streaming | test_text_completion_streaming | ✅ |
| Llama-3.1-8B-Instruct | /completion | non_streaming | test_text_completion_non_streaming | ✅ |
| Llama-3.1-8B-Instruct | /completion | structured_output | test_text_completion_structured_output | ✅ |

## Vector IO
| Provider | API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|:-----|
| inline::faiss | /retrieve | | test_vector_db_retrieve | ✅ |

## Agents
| Provider | API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|:-----|
| inline::meta-reference | /create_agent_turn | rag | test_rag_agent | ✅ |
| inline::meta-reference | /create_agent_turn | custom_tool | test_custom_tool | ✅ |

@@ -1,43 +0,0 @@
# Report for ollama distribution

## Supported Models
| Model Descriptor | ollama |
|:---|:---|
| Llama-3-8B-Instruct | ❌ |
| Llama-3-70B-Instruct | ❌ |
| Llama3.1-8B-Instruct | ✅ |
| Llama3.1-70B-Instruct | ✅ |
| Llama3.1-405B-Instruct | ✅ |
| Llama3.2-1B-Instruct | ✅ |
| Llama3.2-3B-Instruct | ✅ |
| Llama3.2-11B-Vision-Instruct | ✅ |
| Llama3.2-90B-Vision-Instruct | ✅ |
| Llama3.3-70B-Instruct | ✅ |
| Llama-Guard-3-11B-Vision | ❌ |
| Llama-Guard-3-1B | ✅ |
| Llama-Guard-3-8B | ✅ |
| Llama-Guard-2-8B | ❌ |

## Inference
| Model | API | Capability | Test | Status |
|:----- |:-----|:-----|:-----|:-----|
| Llama-3.1-8B-Instruct | /chat_completion | streaming | test_text_chat_completion_streaming | ✅ |
| Llama-3.2-11B-Vision-Instruct | /chat_completion | streaming | test_image_chat_completion_streaming | ❌ |
| Llama-3.2-11B-Vision-Instruct | /chat_completion | non_streaming | test_image_chat_completion_non_streaming | ❌ |
| Llama-3.1-8B-Instruct | /chat_completion | non_streaming | test_text_chat_completion_non_streaming | ✅ |
| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_streaming | ✅ |
| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_non_streaming | ✅ |
| Llama-3.1-8B-Instruct | /completion | streaming | test_text_completion_streaming | ✅ |
| Llama-3.1-8B-Instruct | /completion | non_streaming | test_text_completion_non_streaming | ✅ |
| Llama-3.1-8B-Instruct | /completion | structured_output | test_text_completion_structured_output | ✅ |

## Vector IO
| API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|
| /retrieve | | test_vector_db_retrieve | ✅ |

## Agents
| API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|
| /create_agent_turn | rag | test_rag_agent | ✅ |
| /create_agent_turn | custom_tool | test_custom_tool | ✅ |

@@ -1,44 +0,0 @@
# Report for tgi distribution

## Supported Models
| Model Descriptor | tgi |
|:---|:---|
| Llama-3-8B-Instruct | ✅ |
| Llama-3-70B-Instruct | ✅ |
| Llama3.1-8B-Instruct | ✅ |
| Llama3.1-70B-Instruct | ✅ |
| Llama3.1-405B-Instruct | ✅ |
| Llama3.2-1B-Instruct | ✅ |
| Llama3.2-3B-Instruct | ✅ |
| Llama3.2-11B-Vision-Instruct | ✅ |
| Llama3.2-90B-Vision-Instruct | ✅ |
| Llama3.3-70B-Instruct | ✅ |
| Llama-Guard-3-11B-Vision | ✅ |
| Llama-Guard-3-1B | ✅ |
| Llama-Guard-3-8B | ✅ |
| Llama-Guard-2-8B | ✅ |

## Inference
| Model | API | Capability | Test | Status |
|:----- |:-----|:-----|:-----|:-----|
| Llama-3.1-8B-Instruct | /chat_completion | streaming | test_text_chat_completion_streaming | ✅ |
| Llama-3.2-11B-Vision-Instruct | /chat_completion | streaming | test_image_chat_completion_streaming | ❌ |
| Llama-3.2-11B-Vision-Instruct | /chat_completion | non_streaming | test_image_chat_completion_non_streaming | ❌ |
| Llama-3.1-8B-Instruct | /chat_completion | non_streaming | test_text_chat_completion_non_streaming | ✅ |
| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_streaming | ✅ |
| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_non_streaming | ✅ |
| Llama-3.1-8B-Instruct | /completion | streaming | test_text_completion_streaming | ✅ |
| Llama-3.1-8B-Instruct | /completion | non_streaming | test_text_completion_non_streaming | ✅ |
| Llama-3.1-8B-Instruct | /completion | structured_output | test_text_completion_structured_output | ✅ |

## Vector IO
| API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|
| /retrieve | | test_vector_db_retrieve | ✅ |

## Agents
| API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|
| /create_agent_turn | rag | test_rag_agent | ✅ |
| /create_agent_turn | custom_tool | test_custom_tool | ✅ |
| /create_agent_turn | code_execution | test_code_interpreter_for_attachments | ✅ |

@@ -1,45 +0,0 @@
# Report for together distribution

## Supported Models
| Model Descriptor | together |
|:---|:---|
| Llama-3-8B-Instruct | ❌ |
| Llama-3-70B-Instruct | ❌ |
| Llama3.1-8B-Instruct | ✅ |
| Llama3.1-70B-Instruct | ✅ |
| Llama3.1-405B-Instruct | ✅ |
| Llama3.2-1B-Instruct | ❌ |
| Llama3.2-3B-Instruct | ✅ |
| Llama3.2-11B-Vision-Instruct | ✅ |
| Llama3.2-90B-Vision-Instruct | ✅ |
| Llama3.3-70B-Instruct | ✅ |
| Llama-Guard-3-11B-Vision | ✅ |
| Llama-Guard-3-1B | ❌ |
| Llama-Guard-3-8B | ✅ |
| Llama-Guard-2-8B | ❌ |

## Inference
| Model | API | Capability | Test | Status |
|:----- |:-----|:-----|:-----|:-----|
| Llama-3.1-8B-Instruct | /chat_completion | streaming | test_text_chat_completion_streaming | ✅ |
| Llama-3.2-11B-Vision-Instruct | /chat_completion | streaming | test_image_chat_completion_streaming | ✅ |
| Llama-3.2-11B-Vision-Instruct | /chat_completion | non_streaming | test_image_chat_completion_non_streaming | ✅ |
| Llama-3.1-8B-Instruct | /chat_completion | non_streaming | test_text_chat_completion_non_streaming | ✅ |
| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_streaming | ✅ |
| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_non_streaming | ✅ |
| Llama-3.2-11B-Vision-Instruct | /chat_completion | log_probs | test_completion_log_probs_non_streaming | ✅ |
| Llama-3.2-11B-Vision-Instruct | /chat_completion | log_probs | test_completion_log_probs_streaming | ✅ |
| Llama-3.1-8B-Instruct | /completion | streaming | test_text_completion_streaming | ✅ |
| Llama-3.1-8B-Instruct | /completion | non_streaming | test_text_completion_non_streaming | ✅ |
| Llama-3.1-8B-Instruct | /completion | structured_output | test_text_completion_structured_output | ✅ |

## Vector IO
| Provider | API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|:-----|
| inline::faiss | /retrieve | | test_vector_db_retrieve | ✅ |

## Agents
| Provider | API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|:-----|
| inline::meta-reference | /create_agent_turn | rag | test_rag_agent | ✅ |
| inline::meta-reference | /create_agent_turn | custom_tool | test_custom_tool | ✅ |

@@ -4,6 +4,7 @@ distribution_spec:
  providers:
    inference:
    - remote::watsonx
    - inline::sentence-transformers
    vector_io:
    - inline::faiss
    safety:

@@ -18,6 +18,9 @@ providers:
      url: ${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}
      api_key: ${env.WATSONX_API_KEY:}
      project_id: ${env.WATSONX_PROJECT_ID:}
  - provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
    config: {}
  vector_io:
  - provider_id: faiss
    provider_type: inline::faiss

@@ -191,6 +194,11 @@ models:
  provider_id: watsonx
  provider_model_id: meta-llama/llama-guard-3-11b-vision
  model_type: llm
- metadata:
    embedding_dimension: 384
  model_id: all-MiniLM-L6-v2
  provider_id: sentence-transformers
  model_type: embedding
shields: []
vector_dbs: []
datasets: []

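Note: the run.yaml values above use the ${env.NAME:default} placeholder form (for example ${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}), where the environment variable is read at startup and the text after the colon is the fallback when it is unset. The snippet below is a rough, assumed illustration of that substitution pattern in Python; it is not the stack's actual config resolver.

# Toy resolver for "${env.NAME:default}" placeholders; illustrative only.
import os
import re

_PLACEHOLDER = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):([^}]*)\}")


def expand(value: str) -> str:
    # Substitute each placeholder with the environment value, or the default after ":".
    return _PLACEHOLDER.sub(lambda m: os.environ.get(m.group(1), m.group(2)), value)


print(expand("${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}"))
# Prints the env value if WATSONX_BASE_URL is set, otherwise the IBM Cloud default.
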
@@ -6,7 +6,11 @@

from pathlib import Path

from llama_stack.distribution.datatypes import Provider, ToolGroupInput
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput
from llama_stack.providers.inline.inference.sentence_transformers import (
    SentenceTransformersInferenceConfig,
)
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry

@@ -14,7 +18,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin

def get_distribution_template() -> DistributionTemplate:
    providers = {
        "inference": ["remote::watsonx"],
        "inference": ["remote::watsonx", "inline::sentence-transformers"],
        "vector_io": ["inline::faiss"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],

@@ -36,6 +40,12 @@ def get_distribution_template() -> DistributionTemplate:
        config=WatsonXConfig.sample_run_config(),
    )

    embedding_provider = Provider(
        provider_id="sentence-transformers",
        provider_type="inline::sentence-transformers",
        config=SentenceTransformersInferenceConfig.sample_run_config(),
    )

    available_models = {
        "watsonx": MODEL_ENTRIES,
    }

@@ -50,6 +60,15 @@ def get_distribution_template() -> DistributionTemplate:
        ),
    ]

    embedding_model = ModelInput(
        model_id="all-MiniLM-L6-v2",
        provider_id="sentence-transformers",
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 384,
        },
    )

    default_models = get_model_registry(available_models)
    return DistributionTemplate(
        name="watsonx",

@@ -62,9 +81,9 @@ def get_distribution_template() -> DistributionTemplate:
        run_configs={
            "run.yaml": RunConfigSettings(
                provider_overrides={
                    "inference": [inference_provider],
                    "inference": [inference_provider, embedding_provider],
                },
                default_models=default_models,
                default_models=default_models + [embedding_model],
                default_tool_groups=default_tool_groups,
            ),
        },