mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
# What does this PR do? - as title, cleaning up `import *`'s - upgrade tests to make them more robust to bad model outputs - remove import *'s in llama_stack/apis/* (skip __init__ modules) <img width="465" alt="image" src="https://github.com/user-attachments/assets/d8339c13-3b40-4ba5-9c53-0d2329726ee2" /> - run `sh run_openapi_generator.sh`, no types gets affected ## Test Plan ### Providers Tests **agents** ``` pytest -v -s llama_stack/providers/tests/agents/test_agents.py -m "together" --safety-shield meta-llama/Llama-Guard-3-8B --inference-model meta-llama/Llama-3.1-405B-Instruct-FP8 ``` **inference** ```bash # meta-reference torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py # together pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py pytest ./llama_stack/providers/tests/inference/test_prompt_adapter.py ``` **safety** ``` pytest -v -s llama_stack/providers/tests/safety/test_safety.py -m together --safety-shield meta-llama/Llama-Guard-3-8B ``` **memory** ``` pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m "sentence_transformers" --env EMBEDDING_DIMENSION=384 ``` **scoring** ``` pytest -v -s -m llm_as_judge_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py --judge-model meta-llama/Llama-3.2-3B-Instruct pytest -v -s -m basic_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py pytest -v -s -m braintrust_scoring_together_inference 
llama_stack/providers/tests/scoring/test_scoring.py ``` **datasetio** ``` pytest -v -s -m localfs llama_stack/providers/tests/datasetio/test_datasetio.py pytest -v -s -m huggingface llama_stack/providers/tests/datasetio/test_datasetio.py ``` **eval** ``` pytest -v -s -m meta_reference_eval_together_inference llama_stack/providers/tests/eval/test_eval.py pytest -v -s -m meta_reference_eval_together_inference_huggingface_datasetio llama_stack/providers/tests/eval/test_eval.py ``` ### Client-SDK Tests ``` LLAMA_STACK_BASE_URL=http://localhost:5000 pytest -v ./tests/client-sdk ``` ### llama-stack-apps ``` PORT=5000 LOCALHOST=localhost python -m examples.agents.hello $LOCALHOST $PORT python -m examples.agents.inflation $LOCALHOST $PORT python -m examples.agents.podcast_transcript $LOCALHOST $PORT python -m examples.agents.rag_as_attachments $LOCALHOST $PORT python -m examples.agents.rag_with_memory_bank $LOCALHOST $PORT python -m examples.safety.llama_guard_demo_mm $LOCALHOST $PORT python -m examples.agents.e2e_loop_with_custom_tools $LOCALHOST $PORT # Vision model python -m examples.interior_design_assistant.app python -m examples.agent_store.app $LOCALHOST $PORT ``` ### CLI ``` which llama llama model prompt-format -m Llama3.2-11B-Vision-Instruct llama model list llama stack list-apis llama stack list-providers inference llama stack build --template ollama --image-type conda ``` ### Distributions Tests **ollama** ``` llama stack build --template ollama --image-type conda ollama run llama3.2:1b-instruct-fp16 llama stack run ./llama_stack/templates/ollama/run.yaml --env INFERENCE_MODEL=meta-llama/Llama-3.2-1B-Instruct ``` **fireworks** ``` llama stack build --template fireworks --image-type conda llama stack run ./llama_stack/templates/fireworks/run.yaml ``` **together** ``` llama stack build --template together --image-type conda llama stack run ./llama_stack/templates/together/run.yaml ``` **tgi** ``` llama stack run ./llama_stack/templates/tgi/run.yaml --env 
TGI_URL=http://0.0.0.0:5009 --env INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct ``` ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
350 lines
8.7 KiB
Python
350 lines
8.7 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from enum import Enum
|
|
|
|
from typing import (
|
|
Any,
|
|
AsyncIterator,
|
|
Dict,
|
|
List,
|
|
Literal,
|
|
Optional,
|
|
Protocol,
|
|
runtime_checkable,
|
|
Union,
|
|
)
|
|
|
|
from llama_models.llama3.api.datatypes import (
|
|
BuiltinTool,
|
|
SamplingParams,
|
|
StopReason,
|
|
ToolCall,
|
|
ToolDefinition,
|
|
ToolPromptFormat,
|
|
)
|
|
|
|
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
|
|
|
from pydantic import BaseModel, Field, field_validator
|
|
from typing_extensions import Annotated
|
|
|
|
from llama_stack.apis.common.content_types import InterleavedContent
|
|
|
|
from llama_stack.apis.models import Model
|
|
|
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
|
|
|
|
|
class LogProbConfig(BaseModel):
    """Configuration controlling log-probability reporting for generations."""

    # Number of highest-probability tokens to report per position.
    # NOTE(review): defaults to 0 — presumably "disabled"; confirm provider semantics.
    top_k: Optional[int] = 0
|
|
@json_schema_type
class QuantizationType(Enum):
    """Identifiers for the supported weight-quantization schemes."""

    bf16 = "bf16"
    fp8 = "fp8"
    int4 = "int4"
|
@json_schema_type
class Fp8QuantizationConfig(BaseModel):
    """Selects fp8 quantization; `type` is the union discriminator."""

    type: Literal["fp8"] = "fp8"
|
@json_schema_type
class Bf16QuantizationConfig(BaseModel):
    """Selects bf16 quantization; `type` is the union discriminator."""

    type: Literal["bf16"] = "bf16"
|
@json_schema_type
class Int4QuantizationConfig(BaseModel):
    """Selects int4 quantization; `type` is the union discriminator."""

    type: Literal["int4"] = "int4"
    # Quantization scheme name passed through to the provider.
    # NOTE(review): meaning of the default string is provider-defined — confirm
    # against the implementations that consume it.
    scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
|
# Tagged union of the quantization configs, discriminated on the `type` field.
QuantizationConfig = Annotated[
    Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
    Field(discriminator="type"),
]
|
@json_schema_type
class UserMessage(BaseModel):
    """A message authored by the end user; `role` discriminates the Message union."""

    role: Literal["user"] = "user"
    content: InterleavedContent
    # Optional extra context accompanying the user turn (e.g. retrieved
    # documents) — presumably injected by RAG-style callers; confirm usage.
    context: Optional[InterleavedContent] = None
|
@json_schema_type
class SystemMessage(BaseModel):
    """A system instruction message; `role` discriminates the Message union."""

    role: Literal["system"] = "system"
    content: InterleavedContent
|
@json_schema_type
class ToolResponseMessage(BaseModel):
    """A tool-execution result fed back to the model as a message."""

    role: Literal["ipython"] = "ipython"
    # it was nice to re-use the ToolResponse type, but having all messages
    # have a `content` type makes things nicer too
    # ID of the tool call this message responds to.
    call_id: str
    # Built-in tools are typed; custom tools are identified by name.
    tool_name: Union[BuiltinTool, str]
    content: InterleavedContent
|
@json_schema_type
class CompletionMessage(BaseModel):
    """A message produced by the model (assistant turn)."""

    role: Literal["assistant"] = "assistant"
    content: InterleavedContent
    # Why the model stopped generating this turn.
    stop_reason: StopReason
    # Tool invocations the model requested; empty when no tools were called.
    tool_calls: List[ToolCall] = Field(default_factory=list)
|
# Discriminated union over all message kinds, keyed on the `role` field.
# Registered under the schema name "Message" for OpenAPI generation.
Message = register_schema(
    Annotated[
        Union[
            UserMessage,
            SystemMessage,
            ToolResponseMessage,
            CompletionMessage,
        ],
        Field(discriminator="role"),
    ],
    name="Message",
)
|
@json_schema_type
class ToolResponse(BaseModel):
    """The result of executing a single tool call."""

    # ID of the tool call this response corresponds to.
    call_id: str
    # Built-in tools are typed; custom tools are identified by name.
    tool_name: Union[BuiltinTool, str]
    content: InterleavedContent

    @field_validator("tool_name", mode="before")
    @classmethod
    def validate_field(cls, v):
        # Coerce strings that name a built-in tool into the BuiltinTool enum;
        # any other string is kept as-is (a custom tool name).
        if isinstance(v, str):
            try:
                return BuiltinTool(v)
            except ValueError:
                return v
        return v
|
@json_schema_type
class ToolChoice(Enum):
    """Whether the model may choose to call tools (`auto`) or must (`required`)."""

    auto = "auto"
    required = "required"
|
@json_schema_type
class TokenLogProbs(BaseModel):
    """Log probabilities for a single generated position."""

    # Mapping of token (string form) -> log probability.
    logprobs_by_token: Dict[str, float]
|
@json_schema_type
class ChatCompletionResponseEventType(Enum):
    """Lifecycle stage of a streamed chat-completion event."""

    start = "start"
    complete = "complete"
    progress = "progress"
|
@json_schema_type
class ToolCallParseStatus(Enum):
    """Status of parsing a tool call out of streamed model output."""

    started = "started"
    in_progress = "in_progress"
    failure = "failure"
    success = "success"
|
@json_schema_type
class ToolCallDelta(BaseModel):
    """An incremental piece of a tool call observed while streaming."""

    # Raw text while parsing is in flight; a full ToolCall once parsed.
    content: Union[str, ToolCall]
    parse_status: ToolCallParseStatus
|
@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
    """Chat completion response event."""

    event_type: ChatCompletionResponseEventType
    # Plain text delta, or a tool-call delta while a tool call is streaming.
    delta: Union[str, ToolCallDelta]
    logprobs: Optional[List[TokenLogProbs]] = None
    # Populated on the final event when generation has stopped.
    stop_reason: Optional[StopReason] = None
|
class ResponseFormatType(Enum):
    """Kinds of constrained-output formats a caller can request."""

    json_schema = "json_schema"
    grammar = "grammar"
|
class JsonSchemaResponseFormat(BaseModel):
    """Constrain model output to match a JSON schema."""

    type: Literal[ResponseFormatType.json_schema.value] = (
        ResponseFormatType.json_schema.value
    )
    # The JSON schema the output must conform to.
    json_schema: Dict[str, Any]
|
class GrammarResponseFormat(BaseModel):
    """Constrain model output with a BNF grammar."""

    type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
    # The grammar definition; structure of this dict is provider-interpreted.
    bnf: Dict[str, Any]
|
# Tagged union of the response-format options, discriminated on `type`.
# Registered under the schema name "ResponseFormat" for OpenAPI generation.
ResponseFormat = register_schema(
    Annotated[
        Union[JsonSchemaResponseFormat, GrammarResponseFormat],
        Field(discriminator="type"),
    ],
    name="ResponseFormat",
)
|
@json_schema_type
class CompletionRequest(BaseModel):
    """Request body for a (possibly streaming) raw text completion."""

    model: str
    content: InterleavedContent
    sampling_params: Optional[SamplingParams] = SamplingParams()
    response_format: Optional[ResponseFormat] = None

    # When True the response is delivered as a stream of chunks.
    stream: Optional[bool] = False
    logprobs: Optional[LogProbConfig] = None
|
@json_schema_type
class CompletionResponse(BaseModel):
    """Completion response."""

    content: str
    stop_reason: StopReason
    logprobs: Optional[List[TokenLogProbs]] = None
|
@json_schema_type
class CompletionResponseStreamChunk(BaseModel):
    """streamed completion response."""

    # Incremental text for this chunk.
    delta: str
    # Populated on the final chunk when generation has stopped.
    stop_reason: Optional[StopReason] = None
    logprobs: Optional[List[TokenLogProbs]] = None
|
@json_schema_type
class BatchCompletionRequest(BaseModel):
    """Request body for completing a batch of independent prompts."""

    model: str
    content_batch: List[InterleavedContent]
    sampling_params: Optional[SamplingParams] = SamplingParams()
    response_format: Optional[ResponseFormat] = None
    logprobs: Optional[LogProbConfig] = None
|
@json_schema_type
class BatchCompletionResponse(BaseModel):
    """Batch completion response."""

    # One CompletionResponse per entry in the request's content_batch.
    batch: List[CompletionResponse]
|
@json_schema_type
class ChatCompletionRequest(BaseModel):
    """Request body for a (possibly streaming) chat completion."""

    model: str
    messages: List[Message]
    sampling_params: Optional[SamplingParams] = SamplingParams()

    # zero-shot tool definitions as input to the model
    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
    tool_prompt_format: Optional[ToolPromptFormat] = Field(
        default=ToolPromptFormat.json
    )
    response_format: Optional[ResponseFormat] = None

    # When True the response is delivered as a stream of chunks.
    stream: Optional[bool] = False
    logprobs: Optional[LogProbConfig] = None
|
@json_schema_type
class ChatCompletionResponseStreamChunk(BaseModel):
    """SSE-stream of these events."""

    event: ChatCompletionResponseEvent
|
@json_schema_type
class ChatCompletionResponse(BaseModel):
    """Chat completion response."""

    completion_message: CompletionMessage
    logprobs: Optional[List[TokenLogProbs]] = None
|
@json_schema_type
class BatchChatCompletionRequest(BaseModel):
    """Request body for chat-completing a batch of independent conversations."""

    model: str
    # One message list (conversation) per batch entry.
    messages_batch: List[List[Message]]
    sampling_params: Optional[SamplingParams] = SamplingParams()

    # zero-shot tool definitions as input to the model
    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
    tool_prompt_format: Optional[ToolPromptFormat] = Field(
        default=ToolPromptFormat.json
    )
    logprobs: Optional[LogProbConfig] = None
|
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
    """Batch chat completion response."""

    # One ChatCompletionResponse per entry in the request's messages_batch.
    batch: List[ChatCompletionResponse]
|
@json_schema_type
class EmbeddingsResponse(BaseModel):
    """Embeddings response."""

    # One embedding vector per input content item.
    embeddings: List[List[float]]
|
class ModelStore(Protocol):
    """Structural interface for looking up registered models by identifier."""

    def get_model(self, identifier: str) -> Model: ...
|
@runtime_checkable
@trace_protocol
class Inference(Protocol):
    """Inference API: text completion, chat completion, and embeddings.

    `model_store` provides lookup of registered models by identifier
    (see ModelStore.get_model).
    """

    model_store: ModelStore

    @webmethod(route="/inference/completion")
    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        # NOTE(review): mutable default instance in a signature — presumably
        # safe here because implementations treat it read-only; confirm.
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        """Generate a completion for raw content.

        Returns a full CompletionResponse, or an async iterator of stream
        chunks — presumably when `stream` is True (inferred from the union).
        """
        ...

    @webmethod(route="/inference/chat-completion")
    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        # zero-shot tool definitions as input to the model
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
        """Generate a chat completion for a message history.

        Returns a full ChatCompletionResponse, or an async iterator of stream
        chunks — presumably when `stream` is True (inferred from the union).
        """
        ...

    @webmethod(route="/inference/embeddings")
    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        """Compute an embedding vector for each content item."""
        ...