Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-31 20:30:03 +00:00

commit 6b9f673fdb: 313 changed files with 181388 additions and 7064 deletions

Merge branch 'refs/heads/main' into preprocessors

# Conflicts:
#	llama_stack/distribution/routers/routers.py
#	llama_stack/templates/ollama/build.yaml
#	llama_stack/templates/ollama/run-with-safety.yaml
#	llama_stack/templates/ollama/run.yaml
#	llama_stack/templates/remote-vllm/build.yaml
#	llama_stack/templates/remote-vllm/run-with-safety.yaml
#	llama_stack/templates/remote-vllm/run.yaml
#	llama_stack/templates/together/build.yaml
#	llama_stack/templates/together/run-with-safety.yaml
#	llama_stack/templates/together/run.yaml
@@ -41,16 +41,36 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
class Attachment(BaseModel):
    """An attachment to an agent turn.

    :param content: The content of the attachment.
    :param mime_type: The MIME type of the attachment.
    """

    content: InterleavedContent | URL
    mime_type: str


class Document(BaseModel):
    """A document to be used by an agent.

    :param content: The content of the document.
    :param mime_type: The MIME type of the document.
    """

    content: InterleavedContent | URL
    mime_type: str


class StepCommon(BaseModel):
    """A common step in an agent turn.

    :param turn_id: The ID of the turn.
    :param step_id: The ID of the step.
    :param started_at: The time the step started.
    :param completed_at: The time the step completed.
    """

    turn_id: str
    step_id: str
    started_at: Optional[datetime] = None

@@ -58,6 +78,14 @@ class StepCommon(BaseModel):
class StepType(Enum):
    """Type of the step in an agent turn.

    :cvar inference: The step is an inference step that calls an LLM.
    :cvar tool_execution: The step is a tool execution step that executes a tool call.
    :cvar shield_call: The step is a shield call step that checks for safety violations.
    :cvar memory_retrieval: The step is a memory retrieval step that retrieves context for vector dbs.
    """

    inference = "inference"
    tool_execution = "tool_execution"
    shield_call = "shield_call"

@@ -66,6 +94,11 @@ class StepType(Enum):
@json_schema_type
class InferenceStep(StepCommon):
    """An inference step in an agent turn.

    :param model_response: The response from the LLM.
    """

    model_config = ConfigDict(protected_namespaces=())

    step_type: Literal[StepType.inference.value] = StepType.inference.value

@@ -74,6 +107,12 @@ class InferenceStep(StepCommon):
@json_schema_type
class ToolExecutionStep(StepCommon):
    """A tool execution step in an agent turn.

    :param tool_calls: The tool calls to execute.
    :param tool_responses: The tool responses from the tool calls.
    """

    step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
    tool_calls: List[ToolCall]
    tool_responses: List[ToolResponse]

@@ -81,13 +120,25 @@ class ToolExecutionStep(StepCommon):
@json_schema_type
class ShieldCallStep(StepCommon):
    """A shield call step in an agent turn.

    :param violation: The violation from the shield call.
    """

    step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
    violation: Optional[SafetyViolation]


@json_schema_type
class MemoryRetrievalStep(StepCommon):
    """A memory retrieval step in an agent turn.

    :param vector_db_ids: The IDs of the vector databases to retrieve context from.
    :param inserted_context: The context retrieved from the vector databases.
    """

    step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
    # TODO: should this be List[str]?
    vector_db_ids: str
    inserted_context: InterleavedContent

@@ -148,7 +199,7 @@ AgentToolGroup = register_schema(

class AgentConfigCommon(BaseModel):
    sampling_params: Optional[SamplingParams] = SamplingParams()
    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)

    input_shields: Optional[List[str]] = Field(default_factory=list)
    output_shields: Optional[List[str]] = Field(default_factory=list)
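Editor's note: this diff repeatedly replaces eagerly constructed defaults (`sampling_params: Optional[SamplingParams] = SamplingParams()`) with either `Field(default_factory=SamplingParams)` on models or `= None` on protocol method signatures. A minimal plain-Python sketch of the pitfall being avoided on the function-signature side, using a hypothetical `Params` class rather than the real `SamplingParams`:

# A default expression in a def line runs once, at definition time, and the same
# object is reused for every call; default_factory / `= None` plus an in-body
# construction avoids that sharing.
class Params:
    def __init__(self) -> None:
        self.stop_tokens: list[str] = []


def complete_shared(prompt: str, params: Params = Params()) -> Params:
    params.stop_tokens.append(prompt)  # mutates the single shared default
    return params


def complete_fresh(prompt: str, params: Params | None = None) -> Params:
    params = params or Params()        # builds a new default per call
    params.stop_tokens.append(prompt)
    return params


assert complete_shared("a") is complete_shared("b")    # same object both times
assert complete_fresh("a") is not complete_fresh("b")  # independent objects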
@@ -296,16 +347,13 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
    stream: Optional[bool] = False
    tool_config: Optional[ToolConfig] = None

    # TODO (xiyan): temporary flag, will remove for 0.1.5
    allow_turn_resume: Optional[bool] = False


@json_schema_type
class AgentTurnResumeRequest(BaseModel):
    agent_id: str
    session_id: str
    turn_id: str
    tool_responses: List[ToolResponseMessage]
    tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]]
    stream: Optional[bool] = False

@@ -338,7 +386,13 @@ class Agents(Protocol):
    async def create_agent(
        self,
        agent_config: AgentConfig,
    ) -> AgentCreateResponse: ...
    ) -> AgentCreateResponse:
        """Create an agent with the given configuration.

        :param agent_config: The configuration for the agent.
        :returns: An AgentCreateResponse with the agent ID.
        """
        ...

    @webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
    async def create_agent_turn(

@@ -355,8 +409,19 @@ class Agents(Protocol):
        documents: Optional[List[Document]] = None,
        toolgroups: Optional[List[AgentToolGroup]] = None,
        tool_config: Optional[ToolConfig] = None,
        allow_turn_resume: Optional[bool] = False,
    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
        """Create a new turn for an agent.

        :param agent_id: The ID of the agent to create the turn for.
        :param session_id: The ID of the session to create the turn for.
        :param messages: List of messages to start the turn with.
        :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
        :param documents: (Optional) List of documents to create the turn with.
        :param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
        :param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
        :returns: If stream=False, returns a Turn object.
                  If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk
        """

    @webmethod(
        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",

@@ -367,7 +432,7 @@ class Agents(Protocol):
        agent_id: str,
        session_id: str,
        turn_id: str,
        tool_responses: List[ToolResponseMessage],
        tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
        stream: Optional[bool] = False,
    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
        """Resume an agent turn with executed tool call responses.

@@ -378,6 +443,7 @@ class Agents(Protocol):
        :param session_id: The ID of the session to resume.
        :param turn_id: The ID of the turn to resume.
        :param tool_responses: The tool call responses to resume the turn with.
            NOTE: ToolResponseMessage will be deprecated. Use ToolResponse.
        :param stream: Whether to stream the response.
        :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
        """

@@ -392,7 +458,15 @@ class Agents(Protocol):
        agent_id: str,
        session_id: str,
        turn_id: str,
    ) -> Turn: ...
    ) -> Turn:
        """Retrieve an agent turn by its ID.

        :param agent_id: The ID of the agent to get the turn for.
        :param session_id: The ID of the session to get the turn for.
        :param turn_id: The ID of the turn to get.
        :returns: A Turn.
        """
        ...

    @webmethod(
        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",

@@ -404,14 +478,30 @@ class Agents(Protocol):
        session_id: str,
        turn_id: str,
        step_id: str,
    ) -> AgentStepResponse: ...
    ) -> AgentStepResponse:
        """Retrieve an agent step by its ID.

        :param agent_id: The ID of the agent to get the step for.
        :param session_id: The ID of the session to get the step for.
        :param turn_id: The ID of the turn to get the step for.
        :param step_id: The ID of the step to get.
        :returns: An AgentStepResponse.
        """
        ...

    @webmethod(route="/agents/{agent_id}/session", method="POST")
    async def create_agent_session(
        self,
        agent_id: str,
        session_name: str,
    ) -> AgentSessionCreateResponse: ...
    ) -> AgentSessionCreateResponse:
        """Create a new session for an agent.

        :param agent_id: The ID of the agent to create the session for.
        :param session_name: The name of the session to create.
        :returns: An AgentSessionCreateResponse.
        """
        ...

    @webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET")
    async def get_agents_session(

@@ -419,17 +509,35 @@ class Agents(Protocol):
        session_id: str,
        agent_id: str,
        turn_ids: Optional[List[str]] = None,
    ) -> Session: ...
    ) -> Session:
        """Retrieve an agent session by its ID.

        :param session_id: The ID of the session to get.
        :param agent_id: The ID of the agent to get the session for.
        :param turn_ids: (Optional) List of turn IDs to filter the session by.
        """
        ...

    @webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE")
    async def delete_agents_session(
        self,
        session_id: str,
        agent_id: str,
    ) -> None: ...
    ) -> None:
        """Delete an agent session by its ID.

        :param session_id: The ID of the session to delete.
        :param agent_id: The ID of the agent to delete the session for.
        """
        ...

    @webmethod(route="/agents/{agent_id}", method="DELETE")
    async def delete_agent(
        self,
        agent_id: str,
    ) -> None: ...
    ) -> None:
        """Delete an agent by its ID.

        :param agent_id: The ID of the agent to delete.
        """
        ...
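Editor's note: taken together, the expanded docstrings above describe the agent lifecycle (create agent, create session, create turn, optionally resume with tool responses). A hedged usage sketch against any object implementing this Agents protocol; `impl`, `agent_config`, and `user_message` are placeholders, and the response attribute names (`agent_id`, `session_id`) are inferred from the docstrings rather than shown in the diff:

# Sketch only: assumes `impl` implements the Agents protocol documented above.
async def run_one_turn(impl, agent_config, user_message):
    created = await impl.create_agent(agent_config=agent_config)      # AgentCreateResponse
    session = await impl.create_agent_session(
        agent_id=created.agent_id,                                    # attribute name assumed
        session_name="demo-session",
    )
    turn = await impl.create_agent_turn(
        agent_id=created.agent_id,
        session_id=session.session_id,                                # attribute name assumed
        messages=[user_message],
        stream=False,  # stream=True returns an AsyncIterator of AgentTurnResponseStreamChunk instead
    )
    return turn                                                       # a Turn when stream=False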
@@ -40,7 +40,7 @@ class BatchInference(Protocol):
        self,
        model: str,
        content_batch: List[InterleavedContent],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ) -> BatchCompletionResponse: ...

@@ -50,7 +50,7 @@ class BatchInference(Protocol):
        self,
        model: str,
        messages_batch: List[List[Message]],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        sampling_params: Optional[SamplingParams] = None,
        # zero-shot tool definitions as input to the model
        tools: Optional[List[ToolDefinition]] = list,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -14,6 +14,14 @@ from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class PaginatedRowsResult(BaseModel):
    """
    A paginated list of rows from a dataset.

    :param rows: The rows in the current page.
    :param total_count: The total number of rows in the dataset.
    :param next_page_token: The token to get the next page of rows.
    """

    # the rows obey the DatasetSchema for the given dataset
    rows: List[Dict[str, Any]]
    total_count: int

@@ -36,7 +44,15 @@ class DatasetIO(Protocol):
        rows_in_page: int,
        page_token: Optional[str] = None,
        filter_condition: Optional[str] = None,
    ) -> PaginatedRowsResult: ...
    ) -> PaginatedRowsResult:
        """Get a paginated list of rows from a dataset.

        :param dataset_id: The ID of the dataset to get the rows from.
        :param rows_in_page: The number of rows to get per page.
        :param page_token: The token to get the next page of rows.
        :param filter_condition: (Optional) A condition to filter the rows by.
        """
        ...

    @webmethod(route="/datasetio/rows", method="POST")
    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
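Editor's note: a hedged sketch of walking a dataset page by page using the fields documented on PaginatedRowsResult. The protocol method name `get_rows_paginated` is not visible in this hunk (only its parameters and return type are), so treat it as an assumption:

# Sketch only: `datasetio` is assumed to implement the DatasetIO protocol above.
async def iter_all_rows(datasetio, dataset_id: str, page_size: int = 100):
    token = None
    while True:
        page = await datasetio.get_rows_paginated(   # method name assumed
            dataset_id=dataset_id,
            rows_in_page=page_size,
            page_token=token,
        )
        for row in page.rows:                        # rows follow the dataset's schema
            yield row
        token = page.next_page_token                 # documented on PaginatedRowsResult
        if not token:
            break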
@@ -5,6 +5,9 @@
# the root directory of this source tree.

from enum import Enum
from typing import Optional

from pydantic import BaseModel

from llama_stack.schema_utils import json_schema_type

@@ -35,3 +38,20 @@ class Api(Enum):
    # built-in API
    inspect = "inspect"


@json_schema_type
class Error(BaseModel):
    """
    Error response from the API. Roughly follows RFC 7807.

    :param status: HTTP status code
    :param title: Error title, a short summary of the error which is invariant for an error type
    :param detail: Error detail, a longer human-readable description of the error
    :param instance: (Optional) A URL which can be used to retrieve more information about the specific occurrence of the error
    """

    status: int
    title: str
    detail: str
    instance: Optional[str] = None
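Editor's note: as a quick illustration of the RFC 7807-style shape this model encodes, a hypothetical 404 serialized with Pydantic v2 (which the `.model_dump()` migrations elsewhere in this diff imply); the values are made up:

not_found = Error(
    status=404,
    title="Model not found",  # short summary, invariant per error type
    detail="The requested model is not registered with this stack.",
    instance="https://example.invalid/errors/abc123",  # optional pointer to this occurrence
)
print(not_found.model_dump_json())  # JSON object with status, title, detail, instance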
@@ -19,6 +19,13 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class ModelCandidate(BaseModel):
    """A model candidate for evaluation.

    :param model: The model ID to evaluate.
    :param sampling_params: The sampling parameters for the model.
    :param system_message: (Optional) The system message providing instructions or context to the model.
    """

    type: Literal["model"] = "model"
    model: str
    sampling_params: SamplingParams

@@ -27,6 +34,11 @@ class ModelCandidate(BaseModel):
@json_schema_type
class AgentCandidate(BaseModel):
    """An agent candidate for evaluation.

    :param config: The configuration for the agent candidate.
    """

    type: Literal["agent"] = "agent"
    config: AgentConfig

@@ -39,6 +51,13 @@ EvalCandidate = register_schema(
@json_schema_type
class BenchmarkConfig(BaseModel):
    """A benchmark configuration for evaluation.

    :param eval_candidate: The candidate to evaluate.
    :param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
    :param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
    """

    eval_candidate: EvalCandidate
    scoring_params: Dict[str, ScoringFnParams] = Field(
        description="Map between scoring function id and parameters for each scoring function you want to run",

@@ -53,18 +72,32 @@ class BenchmarkConfig(BaseModel):
@json_schema_type
class EvaluateResponse(BaseModel):
    """The response from an evaluation.

    :param generations: The generations from the evaluation.
    :param scores: The scores from the evaluation.
    """

    generations: List[Dict[str, Any]]
    # each key in the dict is a scoring function name
    scores: Dict[str, ScoringResult]


class Eval(Protocol):
    """Llama Stack Evaluation API for running evaluations on model and agent candidates."""

    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
    async def run_eval(
        self,
        benchmark_id: str,
        task_config: BenchmarkConfig,
    ) -> Job: ...
        benchmark_config: BenchmarkConfig,
    ) -> Job:
        """Run an evaluation on a benchmark.

        :param benchmark_id: The ID of the benchmark to run the evaluation on.
        :param benchmark_config: The configuration for the benchmark.
        :return: The job that was created to run the evaluation.
        """

    @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
    async def evaluate_rows(

@@ -72,14 +105,41 @@ class Eval(Protocol):
        benchmark_id: str,
        input_rows: List[Dict[str, Any]],
        scoring_functions: List[str],
        task_config: BenchmarkConfig,
    ) -> EvaluateResponse: ...
        benchmark_config: BenchmarkConfig,
    ) -> EvaluateResponse:
        """Evaluate a list of rows on a benchmark.

        :param benchmark_id: The ID of the benchmark to run the evaluation on.
        :param input_rows: The rows to evaluate.
        :param scoring_functions: The scoring functions to use for the evaluation.
        :param benchmark_config: The configuration for the benchmark.
        :return: EvaluateResponse object containing generations and scores
        """

    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
    async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: ...
    async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
        """Get the status of a job.

        :param benchmark_id: The ID of the benchmark to run the evaluation on.
        :param job_id: The ID of the job to get the status of.
        :return: The status of the evaluation job.
        """
        ...

    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
    async def job_cancel(self, benchmark_id: str, job_id: str) -> None: ...
    async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
        """Cancel a job.

        :param benchmark_id: The ID of the benchmark to run the evaluation on.
        :param job_id: The ID of the job to cancel.
        """
        ...

    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: ...
    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
        """Get the result of a job.

        :param benchmark_id: The ID of the benchmark to run the evaluation on.
        :param job_id: The ID of the job to get the result of.
        :return: The result of the job.
        """
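Editor's note: the renamed `benchmark_config` parameter and the new docstrings spell out an asynchronous flow: submit a job, poll its status, then fetch the result. A hedged sketch against an Eval implementation; `eval_impl` and `my_benchmark_config` are placeholders, the `job.job_id` attribute and the JobStatus member names are assumptions not shown in this diff:

import asyncio

async def run_benchmark(eval_impl, benchmark_id: str, my_benchmark_config):
    job = await eval_impl.run_eval(
        benchmark_id=benchmark_id,
        benchmark_config=my_benchmark_config,
    )
    while True:
        status = await eval_impl.job_status(benchmark_id=benchmark_id, job_id=job.job_id)
        if status is not None and status.name in ("completed", "failed"):  # member names assumed
            break
        await asyncio.sleep(5)
    return await eval_impl.job_result(benchmark_id=benchmark_id, job_id=job.job_id)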
@@ -278,14 +278,14 @@ ResponseFormat = register_schema(
class CompletionRequest(BaseModel):
    model: str
    content: InterleavedContent
    sampling_params: Optional[SamplingParams] = SamplingParams()
    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
    response_format: Optional[ResponseFormat] = None
    stream: Optional[bool] = False
    logprobs: Optional[LogProbConfig] = None


@json_schema_type
class CompletionResponse(BaseModel):
class CompletionResponse(MetricResponseMixin):
    """Response from a completion request.

    :param content: The generated completion text

@@ -299,7 +299,7 @@ class CompletionResponse(BaseModel):

@json_schema_type
class CompletionResponseStreamChunk(BaseModel):
class CompletionResponseStreamChunk(MetricResponseMixin):
    """A chunk of a streamed completion response.

    :param delta: New content generated since last chunk. This can be one or more tokens.

@@ -357,7 +357,7 @@ class ToolConfig(BaseModel):
class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    sampling_params: Optional[SamplingParams] = SamplingParams()
    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)

    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
    tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)

@@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel):

@json_schema_type
class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
    """A chunk of a streamed chat completion response.

    :param event: The event containing the new content

@@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):

@json_schema_type
class ChatCompletionResponse(MetricResponseMixin, BaseModel):
class ChatCompletionResponse(MetricResponseMixin):
    """Response from a chat completion request.

    :param completion_message: The complete response message

@@ -444,7 +444,7 @@ class Inference(Protocol):
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,

@@ -467,7 +467,7 @@ class Inference(Protocol):
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        sampling_params: Optional[SamplingParams] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
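Editor's note: the response classes here drop the explicit `BaseModel` base and inherit from `MetricResponseMixin` alone, which only works if the mixin is itself a Pydantic model. A small generic sketch of that pattern; the `metrics` field is illustrative, not the actual definition of `MetricResponseMixin`:

from typing import Any, Dict, List, Optional
from pydantic import BaseModel


class ExampleMetricsMixin(BaseModel):
    # stand-in for MetricResponseMixin: since the mixin already derives from
    # BaseModel, subclasses no longer need to list BaseModel explicitly
    metrics: Optional[List[Dict[str, Any]]] = None


class ExampleResponse(ExampleMetricsMixin):
    content: str


resp = ExampleResponse(content="hello", metrics=[{"name": "tokens", "value": 3}])
print(resp.model_dump())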
@@ -17,6 +17,13 @@ ScoringResultRow = Dict[str, Any]
@json_schema_type
class ScoringResult(BaseModel):
    """
    A scoring result for a single row.

    :param score_rows: The scoring result for each row. Each row is a map of column name to value.
    :param aggregated_results: Map of metric name to aggregated value
    """

    score_rows: List[ScoringResultRow]
    # aggregated metrics to value
    aggregated_results: Dict[str, Any]

@@ -30,6 +37,12 @@ class ScoreBatchResponse(BaseModel):
@json_schema_type
class ScoreResponse(BaseModel):
    """
    The response from scoring.

    :param results: A map of scoring function name to ScoringResult.
    """

    # each key in the dict is a scoring function name
    results: Dict[str, ScoringResult]

@@ -55,4 +68,11 @@ class Scoring(Protocol):
        self,
        input_rows: List[Dict[str, Any]],
        scoring_functions: Dict[str, Optional[ScoringFnParams]],
    ) -> ScoreResponse: ...
    ) -> ScoreResponse:
        """Score a list of rows.

        :param input_rows: The rows to score.
        :param scoring_functions: The scoring functions to use for the scoring.
        :return: ScoreResponse object containing rows and aggregated results
        """
        ...
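Editor's note: a hedged sketch of calling the scoring entry point documented above. `scoring_impl` is a placeholder for an object implementing the Scoring protocol, the method name `score` is inferred from the docstring rather than visible in the hunk, and the scoring function id is hypothetical:

async def score_rows(scoring_impl, rows):
    response = await scoring_impl.score(               # method name assumed
        input_rows=rows,                               # List[Dict[str, Any]]
        scoring_functions={"example::equality": None}, # hypothetical scoring function id
    )
    for fn_name, result in response.results.items():
        print(fn_name, result.aggregated_results)      # per-function aggregate metrics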
@@ -64,7 +64,7 @@ class ModelDescribe(Subcommand):
        ]

        if model.recommended_sampling_params is not None:
            sampling_params = model.recommended_sampling_params.dict()
            sampling_params = model.recommended_sampling_params.model_dump()
            for k in ("max_tokens", "repetition_penalty"):
                del sampling_params[k]
            rows.append(
@@ -7,10 +7,14 @@
import argparse
import textwrap
from io import StringIO
from pathlib import Path

from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family

ROOT_DIR = Path(__file__).parent.parent


class ModelPromptFormat(Subcommand):
    """Llama model cli for describe a model prompt format (message formats)"""

@@ -48,7 +52,26 @@ class ModelPromptFormat(Subcommand):
        supported_model_ids = [
            m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
        ]
        model_str = "\n".join([m.value for m in supported_model_ids])

        model_list = [m.value for m in supported_model_ids]
        model_str = "\n".join(model_list)

        if args.list:
            headers = ["Model(s)"]
            rows = []
            for m in model_list:
                rows.append(
                    [
                        m,
                    ]
                )
            print_table(
                rows,
                headers,
                separate_rows=True,
            )
            return

        try:
            model_id = CoreModelId(args.model_name)
        except ValueError:

@@ -57,9 +80,9 @@ class ModelPromptFormat(Subcommand):
        if model_id not in supported_model_ids:
            self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")

        llama_3_1_file = importlib.resources.files("llama_models") / "llama3_1/prompt_format.md"
        llama_3_2_text_file = importlib.resources.files("llama_models") / "llama3_2/text_prompt_format.md"
        llama_3_2_vision_file = importlib.resources.files("llama_models") / "llama3_2/vision_prompt_format.md"
        llama_3_1_file = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"
        llama_3_2_text_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "text_prompt_format.md"
        llama_3_2_vision_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "vision_prompt_format.md"
        if model_family(model_id) == ModelFamily.llama3_1:
            with importlib.resources.as_file(llama_3_1_file) as f:
                content = f.open("r").read()
@@ -141,7 +141,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
            completer=WordCompleter(available_providers),
            complete_while_typing=True,
            validator=Validator.from_callable(
                lambda x: x in available_providers,
                lambda x: x in available_providers,  # noqa: B023 - see https://github.com/astral-sh/ruff/issues/7847
                error_message="Invalid provider, use <TAB> to see options",
            ),
        )

@@ -248,7 +248,7 @@ def _generate_run_config(

    config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
    if hasattr(config_type, "sample_run_config"):
        config = config_type.sample_run_config(__distro_dir__=f"distributions/{image_name}")
        config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}")
    else:
        config = {}
@@ -26,7 +26,7 @@ class StackBuild(Subcommand):
            "--config",
            type=str,
            default=None,
            help="Path to a config file to use for the build. You can find example configs in llama_stack/distribution/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively",
            help="Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively",
        )

        self.parser.add_argument(
@@ -37,7 +37,7 @@ class StackRun(Subcommand):
        self.parser.add_argument(
            "--port",
            type=int,
            help="Port to run the server on. Defaults to 8321",
            help="Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. Defaults to 8321",
            default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
        )
        self.parser.add_argument(
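Editor's note: the new default wires the port to an environment variable with a hard-coded fallback, so precedence is: explicit --port flag, then LLAMA_STACK_PORT, then 8321. A minimal standalone sketch of the same pattern; the parser here is a throwaway, not the StackRun parser:

import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument(
    "--port",
    type=int,
    help="Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. Defaults to 8321",
    default=int(os.getenv("LLAMA_STACK_PORT", 8321)),  # env var is read once, when the parser is built
)
args = parser.parse_args([])  # no flag given, so the env var or 8321 wins
print(args.port)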
@@ -79,12 +79,8 @@ class StackRun(Subcommand):
    def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
        import yaml

        from llama_stack.distribution.build import ImageType
        from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
        from llama_stack.distribution.utils.config_dirs import (
            BUILDS_BASE_DIR,
            DISTRIBS_BASE_DIR,
        )
        from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
        from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty

        config_file = Path(args.config)

@@ -97,14 +93,6 @@ class StackRun(Subcommand):
        if config_file.exists():
            template_name = args.config

        if not config_file.exists() and not has_yaml_suffix:
            # check if it's a build config saved to conda dir
            config_file = Path(BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml")

        if not config_file.exists() and not has_yaml_suffix:
            # check if it's a build config saved to container dir
            config_file = Path(BUILDS_BASE_DIR / ImageType.container.value / f"{args.config}-run.yaml")

        if not config_file.exists() and not has_yaml_suffix:
            # check if it's a build config saved to ~/.llama dir
            config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
@@ -1,127 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime

import pytest
import yaml

from llama_stack.distribution.configure import (
    LLAMA_STACK_RUN_CONFIG_VERSION,
    parse_and_maybe_upgrade_config,
)


@pytest.fixture
def up_to_date_config():
    return yaml.safe_load(
        """
        version: {version}
        image_name: foo
        apis_to_serve: []
        built_at: {built_at}
        providers:
          inference:
            - provider_id: provider1
              provider_type: inline::meta-reference
              config: {{}}
          safety:
            - provider_id: provider1
              provider_type: inline::meta-reference
              config:
                llama_guard_shield:
                  model: Llama-Guard-3-1B
                  excluded_categories: []
                  disable_input_check: false
                  disable_output_check: false
                enable_prompt_guard: false
          memory:
            - provider_id: provider1
              provider_type: inline::meta-reference
              config: {{}}
        """.format(version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat())
    )


@pytest.fixture
def old_config():
    return yaml.safe_load(
        """
        image_name: foo
        built_at: {built_at}
        apis_to_serve: []
        routing_table:
          inference:
            - provider_type: remote::ollama
              config:
                host: localhost
                port: 11434
              routing_key: Llama3.2-1B-Instruct
            - provider_type: inline::meta-reference
              config:
                model: Llama3.1-8B-Instruct
              routing_key: Llama3.1-8B-Instruct
          safety:
            - routing_key: ["shield1", "shield2"]
              provider_type: inline::meta-reference
              config:
                llama_guard_shield:
                  model: Llama-Guard-3-1B
                  excluded_categories: []
                  disable_input_check: false
                  disable_output_check: false
                enable_prompt_guard: false
          memory:
            - routing_key: vector
              provider_type: inline::meta-reference
              config: {{}}
        api_providers:
          telemetry:
            provider_type: noop
            config: {{}}
        """.format(built_at=datetime.now().isoformat())
    )


@pytest.fixture
def invalid_config():
    return yaml.safe_load(
        """
        routing_table: {}
        api_providers: {}
        """
    )


def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config):
    result = parse_and_maybe_upgrade_config(up_to_date_config)
    assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
    assert "inference" in result.providers


def test_parse_and_maybe_upgrade_config_old_format(old_config):
    result = parse_and_maybe_upgrade_config(old_config)
    assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
    assert all(api in result.providers for api in ["inference", "safety", "memory", "telemetry"])
    safety_provider = result.providers["safety"][0]
    assert safety_provider.provider_type == "meta-reference"
    assert "llama_guard_shield" in safety_provider.config

    inference_providers = result.providers["inference"]
    assert len(inference_providers) == 2
    assert set(x.provider_id for x in inference_providers) == {
        "remote::ollama-00",
        "meta-reference-01",
    }

    ollama = inference_providers[0]
    assert ollama.provider_type == "remote::ollama"
    assert ollama.config["port"] == 11434


def test_parse_and_maybe_upgrade_config_invalid(invalid_config):
    with pytest.raises(ValueError):
        parse_and_maybe_upgrade_config(invalid_config)
@@ -15,7 +15,6 @@ from termcolor import cprint

from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.utils.exec import run_command, run_with_pty
from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.providers.datatypes import Api

@@ -103,8 +102,6 @@ def build_image(
            template_or_config,
            image_name,
            container_base,
            str(build_file_path),
            str(BUILDS_BASE_DIR / ImageType.container.value),
            " ".join(normal_deps),
        ]
    elif build_config.image_type == ImageType.conda.value:
@@ -6,8 +6,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694

@@ -16,8 +16,8 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
if [ -n "$LLAMA_STACK_DIR" ]; then
  echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
  echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
  echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
fi

if [ "$#" -lt 3 ]; then

@@ -52,7 +52,7 @@ ensure_conda_env_python310() {
  local python_version="3.10"

  # Check if conda command is available
  if ! command -v conda &>/dev/null; then
  if ! is_command_available conda; then
    printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
    exit 1
  fi

@@ -87,8 +87,6 @@ ensure_conda_env_python310() {
    # these packages are damaged in test-pypi, so install them first
    uv pip install fastapi libcst
    uv pip install --extra-index-url https://test.pypi.org/simple/ \
      llama-models==$TEST_PYPI_VERSION \
      llama-stack-client==$TEST_PYPI_VERSION \
      llama-stack==$TEST_PYPI_VERSION \
      $pip_dependencies
    if [ -n "$special_pip_deps" ]; then

@@ -111,22 +109,21 @@ ensure_conda_env_python310() {
  else
    PYPI_VERSION="${PYPI_VERSION:-}"
    if [ -n "$PYPI_VERSION" ]; then
      SPEC_VERSION="llama-stack==${PYPI_VERSION} llama-models==${PYPI_VERSION} llama-stack-client==${PYPI_VERSION}"
      SPEC_VERSION="llama-stack==${PYPI_VERSION}"
    else
      SPEC_VERSION="llama-stack"
    fi
    uv pip install --no-cache-dir $SPEC_VERSION
  fi

  if [ -n "$LLAMA_MODELS_DIR" ]; then
    if [ ! -d "$LLAMA_MODELS_DIR" ]; then
      printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}\n" >&2
  if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
    if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
      printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: $LLAMA_STACK_CLIENT_DIR${NC}\n" >&2
      exit 1
    fi

    printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR\n"
    uv pip uninstall llama-models
    uv pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
    printf "Installing from LLAMA_STACK_CLIENT_DIR: $LLAMA_STACK_CLIENT_DIR\n"
    uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
  fi

  # Install pip dependencies
@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

@@ -6,7 +6,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}

@@ -20,26 +19,27 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
# mounting is not supported by docker buildx, so we use COPY instead
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}

if [ "$#" -lt 6 ]; then
if [ "$#" -lt 4 ]; then
  # This only works for templates
  echo "Usage: $0 <template_or_config> <image_name> <container_base> <build_file_path> <host_build_dir> <pip_dependencies> [<special_pip_deps>]" >&2
  echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<special_pip_deps>]" >&2
  exit 1
fi

set -euo pipefail

template_or_config="$1"
image_name="$2"
container_base="$3"
build_file_path="$4"
host_build_dir="$5"
pip_dependencies="$6"
special_pip_deps="${7:-}"
shift
image_name="$1"
shift
container_base="$1"
shift
pip_dependencies="$1"
shift
special_pip_deps="${1:-}"


# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color

CONTAINER_BINARY=${CONTAINER_BINARY:-docker}

@@ -47,8 +47,10 @@ CONTAINER_OPTS=${CONTAINER_OPTS:-}

TEMP_DIR=$(mktemp -d)

SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"

add_to_container() {
  local input
  output_file="$TEMP_DIR/Containerfile"
  if [ -t 0 ]; then
    printf '%s\n' "$1" >>"$output_file"

@@ -58,15 +60,21 @@ add_to_container() {
  fi
}

# Check if container command is available
if ! is_command_available $CONTAINER_BINARY; then
  printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
  exit 1
fi

# Update and install UBI9 components if UBI9 base image is used
if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
  add_to_container << EOF
FROM $container_base
WORKDIR /app

RUN microdnf -y update && microdnf install -y iputils net-tools wget \
RUN dnf -y update && dnf install -y iputils net-tools wget \
  vim-minimal python3.11 python3.11-pip python3.11-wheel \
  python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && microdnf clean all
  python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all

ENV UV_SYSTEM_PYTHON=1
RUN pip install uv

@@ -107,7 +115,6 @@ EOF
fi

stack_mount="/app/llama-stack-source"
models_mount="/app/llama-models-source"
client_mount="/app/llama-stack-client-source"

install_local_package() {

@@ -131,10 +138,6 @@ EOF
}


if [ -n "$LLAMA_MODELS_DIR" ]; then
  install_local_package "$LLAMA_MODELS_DIR" "$models_mount" "LLAMA_MODELS_DIR"
fi

if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
  install_local_package "$LLAMA_STACK_CLIENT_DIR" "$client_mount" "LLAMA_STACK_CLIENT_DIR"
fi

@@ -150,12 +153,12 @@ EOF
  add_to_container << EOF
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
  --index-strategy unsafe-best-match \
  llama-models==$TEST_PYPI_VERSION llama-stack-client==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION
  llama-stack==$TEST_PYPI_VERSION

EOF
else
  if [ -n "$PYPI_VERSION" ]; then
    SPEC_VERSION="llama-stack==${PYPI_VERSION} llama-models==${PYPI_VERSION} llama-stack-client==${PYPI_VERSION}"
    SPEC_VERSION="llama-stack==${PYPI_VERSION}"
  else
    SPEC_VERSION="llama-stack"
  fi

@@ -165,6 +168,11 @@ EOF
  fi
fi

# remove uv after installation
add_to_container << EOF
RUN pip uninstall -y uv
EOF

# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
if [[ "$template_or_config" != *.yaml ]]; then
  add_to_container << EOF

@@ -185,26 +193,28 @@ RUN mkdir -p /.llama /.cache
RUN chmod -R g+rw /app /.llama /.cache
EOF

printf "Containerfile created successfully in $TEMP_DIR/Containerfile\n\n"
cat $TEMP_DIR/Containerfile
printf "Containerfile created successfully in %s/Containerfile\n\n" "$TEMP_DIR"
cat "$TEMP_DIR"/Containerfile
printf "\n"

mounts=""
# Start building the CLI arguments
CLI_ARGS=()

# Read CONTAINER_OPTS and put it in an array
read -ra CLI_ARGS <<< "$CONTAINER_OPTS"

if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
  if [ -n "$LLAMA_STACK_DIR" ]; then
    mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):$stack_mount"
  fi
  if [ -n "$LLAMA_MODELS_DIR" ]; then
    mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
    CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_DIR"):$stack_mount")
  fi
  if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
    mounts="$mounts -v $(readlink -f $LLAMA_STACK_CLIENT_DIR):$client_mount"
    CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_CLIENT_DIR"):$client_mount")
  fi
fi

if command -v selinuxenabled &>/dev/null && selinuxenabled; then
if is_command_available selinuxenabled && selinuxenabled; then
  # Disable SELinux labels -- we don't want to relabel the llama-stack source dir
  CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
  CLI_ARGS+=("--security-opt" "label=disable")
fi

# Set version tag based on PyPI version

@@ -212,7 +222,7 @@ if [ -n "$PYPI_VERSION" ]; then
  version_tag="$PYPI_VERSION"
elif [ -n "$TEST_PYPI_VERSION" ]; then
  version_tag="test-$TEST_PYPI_VERSION"
elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_MODELS_DIR" ]]; then
elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_STACK_CLIENT_DIR" ]]; then
  version_tag="dev"
else
  URL="https://pypi.org/pypi/llama-stack/json"

@@ -225,11 +235,11 @@ image_tag="$image_name:$version_tag"
# Detect platform architecture
ARCH=$(uname -m)
if [ -n "$BUILD_PLATFORM" ]; then
  PLATFORM="--platform $BUILD_PLATFORM"
  CLI_ARGS+=("--platform $BUILD_PLATFORM")
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
  PLATFORM="--platform linux/arm64"
  CLI_ARGS+=("--platform" "linux/arm64")
elif [ "$ARCH" = "x86_64" ]; then
  PLATFORM="--platform linux/amd64"
  CLI_ARGS+=("--platform" "linux/amd64")
else
  echo "Unsupported architecture: $ARCH"
  exit 1

@@ -238,8 +248,13 @@ fi
echo "PWD: $(pwd)"
echo "Containerfile: $TEMP_DIR/Containerfile"
set -x
$CONTAINER_BINARY build $CONTAINER_OPTS $PLATFORM -t $image_tag \
  -f "$TEMP_DIR/Containerfile" "." $mounts --progress=plain

$CONTAINER_BINARY build \
  "${CLI_ARGS[@]}" \
  -t "$image_tag" \
  -f "$TEMP_DIR/Containerfile" \
  "." \
  --progress=plain

# clean up tmp/configs
set +x
@@ -9,8 +9,8 @@
# TODO: combine this with build_conda_env.sh since it is almost identical
# the only difference is that we don't do any conda-specific setup

LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694

@@ -21,8 +21,8 @@ VIRTUAL_ENV=${VIRTUAL_ENV:-}
if [ -n "$LLAMA_STACK_DIR" ]; then
  echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
  echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
  echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
fi

if [ "$#" -lt 2 ]; then

@@ -95,7 +95,7 @@ run() {
    # we are building a command line so word splitting is expected
    uv pip install --extra-index-url https://test.pypi.org/simple/ \
      --index-strategy unsafe-best-match \
      llama-models=="$TEST_PYPI_VERSION" llama-stack=="$TEST_PYPI_VERSION" \
      llama-stack=="$TEST_PYPI_VERSION" \
      $pip_dependencies
    if [ -n "$special_pip_deps" ]; then
      IFS='#' read -ra parts <<<"$special_pip_deps"

@@ -120,15 +120,14 @@ run() {
    uv pip install --no-cache-dir llama-stack
  fi

  if [ -n "$LLAMA_MODELS_DIR" ]; then
    if [ ! -d "$LLAMA_MODELS_DIR" ]; then
      printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_MODELS_DIR" >&2
  if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
    if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
      printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2
      exit 1
    fi

    printf "Installing from LLAMA_MODELS_DIR: %s\n" "$LLAMA_MODELS_DIR"
    uv pip uninstall llama-models
    uv pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
    printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR"
    uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
  fi

  # Install pip dependencies
@@ -39,7 +39,7 @@ def configure_single_provider(registry: Dict[str, ProviderSpec], provider: Provider
    return Provider(
        provider_id=provider.provider_id,
        provider_type=provider.provider_type,
        config=cfg.dict(),
        config=cfg.model_dump(),
    )
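Editor's note: this hunk, like the `recommended_sampling_params.model_dump()` change in the `llama model describe` CLI and the `run_config.model_dump()` change in the resolver below, swaps the Pydantic v1-style serialization call for its v2 equivalent. A tiny illustration with a throwaway model:

from pydantic import BaseModel


class Example(BaseModel):
    provider_id: str = "provider1"
    port: int = 11434


e = Example()
print(e.model_dump())  # Pydantic v2 spelling: {'provider_id': 'provider1', 'port': 11434}
# e.dict() still works on v2 but emits a deprecation warning, which is why the diff migrates it.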
@@ -1,47 +0,0 @@
#!/bin/bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
CONTAINER_OPTS=${CONTAINER_OPTS:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}

set -euo pipefail

error_handler() {
  echo "Error occurred in script at line: ${1}" >&2
  exit 1
}

trap 'error_handler ${LINENO}' ERR

if [ $# -lt 2 ]; then
  echo "Usage: $0 <container name> <build file path>"
  exit 1
fi

container_image="$1"
host_build_dir="$2"
container_build_dir="/app/builds"

if command -v selinuxenabled &> /dev/null && selinuxenabled; then
  # Disable SELinux labels
  CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
fi

mounts=""
if [ -n "$LLAMA_STACK_DIR" ]; then
  mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):/app/llama-stack-source"
fi

set -x
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
  --entrypoint "/usr/local/bin/llama" \
  -v $host_build_dir:$container_build_dir \
  $mounts \
  $container_image \
  stack configure ./llamastack-build.yaml --output-dir $container_build_dir
@@ -13,7 +13,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec

def stack_apis() -> List[Api]:
    return [v for v in Api]
    return list(Api)


class AutoRoutedApiInfo(BaseModel):

@@ -59,7 +59,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:

def providable_apis() -> List[Api]:
    routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
    routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
    return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
@@ -104,7 +104,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
        logger.warning(
            f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
        )
        return value
        raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e


class LlamaStackAsLibraryClient(LlamaStackClient):
@@ -5,9 +5,9 @@
# the root directory of this source tree.
import importlib
import inspect
import logging
from typing import Any, Dict, List, Set
from typing import Any, Dict, List, Set, Tuple

from llama_stack import logcat
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO

@@ -53,8 +53,6 @@ from llama_stack.providers.datatypes import (
    VectorDBsProtocolPrivate,
)

log = logging.getLogger(__name__)


class InvalidProviderError(Exception):
    pass

@@ -110,60 +108,43 @@ class ProviderWithSpec(Provider):
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]


# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls(
    run_config: StackRunConfig,
    provider_registry: ProviderRegistry,
    dist_registry: DistributionRegistry,
) -> Dict[Api, Any]:
    """
    Does two things:
    - flatmaps, sorts and resolves the providers in dependency order
    - for each API, produces either a (local, passthrough or router) implementation
    Resolves provider implementations by:
    1. Validating and organizing providers.
    2. Sorting them in dependency order.
    3. Instantiating them with required dependencies.
    """
    routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
    router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
    routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
    router_apis = {x.router_api for x in builtin_automatically_routed_apis()}

    providers_with_specs = {}

    for api_str, providers in run_config.providers.items():
        api = Api(api_str)
        if api in routing_table_apis:
            raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")

        specs = {}
        for provider in providers:
            if provider.provider_type not in provider_registry[api]:
                raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")

            p = provider_registry[api][provider.provider_type]
            if p.deprecation_error:
                log.error(p.deprecation_error, "red", attrs=["bold"])
                raise InvalidProviderError(p.deprecation_error)

            elif p.deprecation_warning:
                log.warning(
                    f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
                )
            p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
            spec = ProviderWithSpec(
                spec=p,
                **(provider.model_dump()),
            )
            specs[provider.provider_id] = spec

        key = api_str if api not in router_apis else f"inner-{api_str}"
        providers_with_specs[key] = specs
    providers_with_specs = validate_and_prepare_providers(
        run_config, provider_registry, routing_table_apis, router_apis
    )

    apis_to_serve = run_config.apis or set(
        list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
    )

    providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))

    sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)

    return await instantiate_providers(sorted_providers, router_apis, dist_registry)


def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]:
    """Generates specifications for automatically routed APIs."""
    specs = {}
    for info in builtin_automatically_routed_apis():
        if info.router_api.value not in apis_to_serve:
            continue

        providers_with_specs[info.routing_table_api.value] = {
        specs[info.routing_table_api.value] = {
            "__builtin__": ProviderWithSpec(
                provider_id="__routing_table__",
                provider_type="__routing_table__",

@@ -173,12 +154,12 @@ async def resolve_impls(
                    router_api=info.router_api,
                    module="llama_stack.distribution.routers",
                    api_dependencies=[],
                    deps__=([f"inner-{info.router_api.value}"]),
                    deps__=[f"inner-{info.router_api.value}"],
                ),
            )
        }

        providers_with_specs[info.router_api.value] = {
        specs[info.router_api.value] = {
            "__builtin__": ProviderWithSpec(
                provider_id="__autorouted__",
                provider_type="__autorouted__",

@@ -188,12 +169,69 @@ async def resolve_impls(
                    module="llama_stack.distribution.routers",
                    routing_table_api=info.routing_table_api,
                    api_dependencies=[info.routing_table_api],
                    deps__=([info.routing_table_api.value]),
                    # Add telemetry as an optional dependency to all auto-routed providers
                    optional_api_dependencies=[Api.telemetry],
                    deps__=([info.routing_table_api.value, Api.telemetry.value]),
                ),
            )
        }
    return specs

    sorted_providers = topological_sort({k: v.values() for k, v in providers_with_specs.items()})


def validate_and_prepare_providers(
    run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api]
) -> Dict[str, Dict[str, ProviderWithSpec]]:
    """Validates providers, handles deprecations, and organizes them into a spec dictionary."""
    providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {}

    for api_str, providers in run_config.providers.items():
        api = Api(api_str)
        if api in routing_table_apis:
            raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")

        specs = {}
        for provider in providers:
            if not provider.provider_id or provider.provider_id == "__disabled__":
                logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
                continue

            validate_provider(provider, api, provider_registry)
            p = provider_registry[api][provider.provider_type]
            p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
            spec = ProviderWithSpec(spec=p, **provider.model_dump())
            specs[provider.provider_id] = spec

        key = api_str if api not in router_apis else f"inner-{api_str}"
        providers_with_specs[key] = specs

    return providers_with_specs


def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
    """Validates if the provider is allowed and handles deprecations."""
    if provider.provider_type not in provider_registry[api]:
        raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")

    p = provider_registry[api][provider.provider_type]
    if p.deprecation_error:
        logcat.error("core", p.deprecation_error)
        raise InvalidProviderError(p.deprecation_error)
    elif p.deprecation_warning:
        logcat.warning(
            "core",
            f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
        )


def sort_providers_by_deps(
    providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig
) -> List[Tuple[str, ProviderWithSpec]]:
    """Sorts providers based on their dependencies."""
    sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort(
        {k: list(v.values()) for k, v in providers_with_specs.items()}
    )

    # Append built-in "inspect" provider
    apis = [x[1].spec.api for x in sorted_providers]
    sorted_providers.append(
        (

@@ -201,28 +239,31 @@ async def resolve_impls(
            ProviderWithSpec(
                provider_id="__builtin__",
                provider_type="__builtin__",
                config={
                    "run_config": run_config.dict(),
                },
                config={"run_config": run_config.model_dump()},
                spec=InlineProviderSpec(
                    api=Api.inspect,
                    provider_type="__builtin__",
                    config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
                    module="llama_stack.distribution.inspect",
                    api_dependencies=apis,
                    deps__=([x.value for x in apis]),
                    deps__=[x.value for x in apis],
                ),
            ),
        )
    )

    log.info(f"Resolved {len(sorted_providers)} providers")
    logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
    for api_str, provider in sorted_providers:
        log.info(f"  {api_str} => {provider.provider_id}")
    log.info("")
        logcat.debug("core", f" {api_str} => {provider.provider_id}")
    return sorted_providers

    impls = {}
    inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}


async def instantiate_providers(
    sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry
) -> Dict:
    """Instantiates providers asynchronously while managing dependencies."""
|
||||
impls: Dict[Api, Any] = {}
|
||||
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
||||
for api_str, provider in sorted_providers:
|
||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||
for a in provider.spec.optional_api_dependencies:
|
||||
|
|
@ -233,14 +274,9 @@ async def resolve_impls(
|
|||
if isinstance(provider.spec, RoutingTableProviderSpec):
|
||||
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
|
||||
|
||||
impl = await instantiate_provider(
|
||||
provider,
|
||||
deps,
|
||||
inner_impls,
|
||||
dist_registry,
|
||||
)
|
||||
# TODO: ugh slightly redesign this shady looking code
|
||||
if "inner-" in api_str:
|
||||
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
|
||||
|
||||
if api_str.startswith("inner-"):
|
||||
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
|
||||
else:
|
||||
api = Api(api_str)
|
||||
|
|
@ -251,7 +287,7 @@ async def resolve_impls(
|
|||
|
||||
def topological_sort(
providers_with_specs: Dict[str, List[ProviderWithSpec]],
) -> List[ProviderWithSpec]:
) -> List[Tuple[str, ProviderWithSpec]]:
def dfs(kv, visited: Set[str], stack: List[str]):
api_str, providers = kv
visited.add(api_str)

@@ -267,8 +303,8 @@ def topological_sort(

stack.append(api_str)

visited = set()
stack = []
visited: Set[str] = set()
stack: List[str] = []

for api_str, providers in providers_with_specs.items():
if api_str not in visited:

@@ -278,13 +314,14 @@ def topological_sort(

for api_str in stack:
for provider in providers_with_specs[api_str]:
flattened.append((api_str, provider))

return flattened
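The refactored topological_sort above returns (api_str, provider) pairs in dependency order. As a rough standalone illustration of that ordering idea (plain strings stand in for ProviderWithSpec, and the example keys are made up, not actual registry contents), a sketch might look like this:

# Minimal, self-contained sketch of dependency-ordered flattening.
from typing import Dict, List, Set


def order_by_deps(deps_by_name: Dict[str, List[str]]) -> List[str]:
    """Return names so that every dependency appears before its dependents."""
    visited: Set[str] = set()
    ordered: List[str] = []

    def dfs(name: str) -> None:
        visited.add(name)
        for dep in deps_by_name.get(name, []):
            if dep not in visited:
                dfs(dep)
        ordered.append(name)  # emitted only after all of its dependencies

    for name in deps_by_name:
        if name not in visited:
            dfs(name)
    return ordered


if __name__ == "__main__":
    # e.g. an auto-routed inference API depending on its inner providers and telemetry
    print(order_by_deps({
        "inference": ["inner-inference", "telemetry"],
        "inner-inference": [],
        "telemetry": [],
    }))
    # -> ['inner-inference', 'telemetry', 'inference']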
# returns a class implementing the protocol corresponding to the Api
|
||||
async def instantiate_provider(
|
||||
provider: ProviderWithSpec,
|
||||
deps: Dict[str, Any],
|
||||
deps: Dict[Api, Any],
|
||||
inner_impls: Dict[str, Any],
|
||||
dist_registry: DistributionRegistry,
|
||||
):
|
||||
|
|
@ -292,8 +329,10 @@ async def instantiate_provider(
|
|||
additional_protocols = additional_protocols_map()
|
||||
|
||||
provider_spec = provider.spec
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
if not hasattr(provider_spec, "module"):
|
||||
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
|
||||
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
args = []
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
|
|
@ -356,7 +395,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
|||
obj_params = set(obj_sig.parameters)
|
||||
obj_params.discard("self")
|
||||
if not (proto_params <= obj_params):
|
||||
log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
|
||||
logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
|
||||
missing_methods.append((name, "signature_mismatch"))
|
||||
else:
|
||||
# Check if the method is actually implemented in the class
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ async def get_routing_table_impl(
|
|||
return impl
|
||||
|
||||
|
||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
from .routers import (
DatasetIORouter,
EvalRouter,

@@ -69,9 +69,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->

"tool_runtime": ToolRuntimeRouter,
"preprocessing": PreprocessingRouter,
}
api_to_deps = {
"inference": {"telemetry": Api.telemetry},
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")

impl = api_to_routers[api.value](routing_table)
api_to_dep_impl = {}
for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
if dep_api in deps:
api_to_dep_impl[dep_name] = deps[dep_api]

impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
await impl.initialize()
return impl
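The new api_to_deps table lets auto-routed routers receive optional implementations (currently telemetry for the inference router) as keyword arguments. A minimal sketch of that injection pattern, using stand-in FakeRouter and FakeTelemetry classes that are assumptions rather than llama-stack types:

from typing import Any, Dict


class FakeTelemetry:
    pass


class FakeRouter:
    def __init__(self, routing_table: Any, telemetry: Any = None) -> None:
        self.routing_table = routing_table
        self.telemetry = telemetry


def build_router(api: str, routing_table: Any, deps: Dict[str, Any]) -> FakeRouter:
    # map each API to the optional dependencies its router constructor accepts
    api_to_deps = {"inference": {"telemetry": "telemetry"}}
    kwargs = {
        kwarg: deps[dep_key]
        for kwarg, dep_key in api_to_deps.get(api, {}).items()
        if dep_key in deps  # only forward dependencies that were actually resolved
    }
    return FakeRouter(routing_table, **kwargs)


router = build_router("inference", routing_table=object(), deps={"telemetry": FakeTelemetry()})
assert router.telemetry is not None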
|
||||
|
|
|
|||
|
|
@ -4,8 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
InterleavedContent,
|
||||
|
|
@ -20,6 +22,10 @@ from llama_stack.apis.eval import (
|
|||
JobStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
|
|
@ -27,13 +33,14 @@ from llama_stack.apis.inference import (
|
|||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.preprocessing import (
|
||||
Preprocessing,
|
||||
PreprocessingDataElement,
|
||||
|
|
@ -49,6 +56,7 @@ from llama_stack.apis.scoring import (
|
|||
ScoringFnParams,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.telemetry import MetricEvent, Telemetry
|
||||
from llama_stack.apis.tools import (
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
|
|
@ -59,8 +67,10 @@ from llama_stack.apis.tools import (
|
|||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.distribution.utils.chain import execute_preprocessor_chain
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
|
||||
class VectorIORouter(VectorIO):
|
||||
|
|
@ -70,12 +80,15 @@ class VectorIORouter(VectorIO):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing VectorIORouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "VectorIORouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "VectorIORouter.shutdown")
|
||||
pass
|
||||
|
||||
async def register_vector_db(
|
||||
|
|
@ -86,6 +99,10 @@ class VectorIORouter(VectorIO):
|
|||
provider_id: Optional[str] = None,
|
||||
provider_vector_db_id: Optional[str] = None,
|
||||
) -> None:
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}",
|
||||
)
|
||||
await self.routing_table.register_vector_db(
|
||||
vector_db_id,
|
||||
embedding_model,
|
||||
|
|
@ -100,6 +117,10 @@ class VectorIORouter(VectorIO):
|
|||
chunks: List[Chunk],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||
|
||||
async def query_chunks(
|
||||
|
|
@ -108,6 +129,7 @@ class VectorIORouter(VectorIO):
|
|||
query: InterleavedContent,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryChunksResponse:
|
||||
logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
|
||||
|
||||
|
||||
|
|
@ -117,13 +139,21 @@ class InferenceRouter(Inference):
|
|||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
telemetry: Optional[Telemetry] = None,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing InferenceRouter")
|
||||
self.routing_table = routing_table
|
||||
self.telemetry = telemetry
|
||||
if self.telemetry:
|
||||
self.tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(self.tokenizer)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "InferenceRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "InferenceRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def register_model(
|
||||
|
|
@ -134,13 +164,68 @@ class InferenceRouter(Inference):
|
|||
metadata: Optional[Dict[str, Any]] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> None:
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
|
||||
)
|
||||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||
|
||||
def _construct_metrics(
|
||||
self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
|
||||
) -> List[MetricEvent]:
|
||||
span = get_current_span()
|
||||
metrics = [
|
||||
("prompt_tokens", prompt_tokens),
|
||||
("completion_tokens", completion_tokens),
|
||||
("total_tokens", total_tokens),
|
||||
]
|
||||
metric_events = []
|
||||
for metric_name, value in metrics:
|
||||
metric_events.append(
|
||||
MetricEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
metric=metric_name,
|
||||
value=value,
|
||||
timestamp=time.time(),
|
||||
unit="tokens",
|
||||
attributes={
|
||||
"model_id": model.model_id,
|
||||
"provider_id": model.provider_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
return metric_events
|
||||
|
||||
async def _compute_and_log_token_usage(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
model: Model,
|
||||
) -> List[MetricEvent]:
|
||||
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
|
||||
if self.telemetry:
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
return metrics
|
||||
|
||||
async def _count_tokens(
|
||||
self,
|
||||
messages: List[Message] | InterleavedContent,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
) -> Optional[int]:
|
||||
if isinstance(messages, list):
|
||||
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
|
||||
else:
|
||||
encoded = self.formatter.encode_content(messages)
|
||||
return len(encoded.tokens) if encoded and encoded.tokens else 0
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = None,
|
||||
|
|
@ -148,7 +233,13 @@ class InferenceRouter(Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||
)
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
|
|
@ -167,8 +258,6 @@ class InferenceRouter(Inference):
|
|||
params["tool_prompt_format"] = tool_prompt_format
|
||||
tool_config = ToolConfig(**params)
|
||||
|
||||
tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)
|
||||
|
||||
tools = tools or []
|
||||
if tool_config.tool_choice == ToolChoice.none:
|
||||
tools = []
|
||||
|
|
@ -195,20 +284,63 @@ class InferenceRouter(Inference):
|
|||
tool_config=tool_config,
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
|
||||
|
||||
if stream:
|
||||
return (chunk async for chunk in await provider.chat_completion(**params))
|
||||
|
||||
async def stream_generator():
|
||||
completion_text = ""
|
||||
async for chunk in await provider.chat_completion(**params):
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
|
||||
if chunk.event.delta.type == "text":
|
||||
completion_text += chunk.event.delta.text
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
||||
completion_tokens = await self._count_tokens(
|
||||
[CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
|
||||
tool_config.tool_prompt_format,
|
||||
)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||
yield chunk
|
||||
|
||||
return stream_generator()
|
||||
else:
|
||||
return await provider.chat_completion(**params)
|
||||
response = await provider.chat_completion(**params)
|
||||
completion_tokens = await self._count_tokens(
|
||||
[response.completion_message],
|
||||
tool_config.tool_prompt_format,
|
||||
)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
return response
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
|
||||
)
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
|
|
@ -223,10 +355,41 @@ class InferenceRouter(Inference):
|
|||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
prompt_tokens = await self._count_tokens(content)
|
||||
|
||||
if stream:
|
||||
return (chunk async for chunk in await provider.completion(**params))
|
||||
|
||||
async def stream_generator():
|
||||
completion_text = ""
|
||||
async for chunk in await provider.completion(**params):
|
||||
if hasattr(chunk, "delta"):
|
||||
completion_text += chunk.delta
|
||||
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
|
||||
completion_tokens = await self._count_tokens(completion_text)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||
yield chunk
|
||||
|
||||
return stream_generator()
|
||||
else:
|
||||
return await provider.completion(**params)
|
||||
response = await provider.completion(**params)
|
||||
completion_tokens = await self._count_tokens(response.content)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
return response
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
|
@ -236,6 +399,7 @@ class InferenceRouter(Inference):
|
|||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
logcat.debug("core", f"InferenceRouter.embeddings: {model_id}")
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
|
|
@ -255,12 +419,15 @@ class SafetyRouter(Safety):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing SafetyRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "SafetyRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "SafetyRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def register_shield(
|
||||
|
|
@ -270,6 +437,7 @@ class SafetyRouter(Safety):
|
|||
provider_id: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> Shield:
|
||||
logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}")
|
||||
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||
|
||||
async def run_shield(
|
||||
|
|
@ -278,6 +446,7 @@ class SafetyRouter(Safety):
|
|||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}")
|
||||
return await self.routing_table.get_provider_impl(shield_id).run_shield(
|
||||
shield_id=shield_id,
|
||||
messages=messages,
|
||||
|
|
@ -290,12 +459,15 @@ class DatasetIORouter(DatasetIO):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing DatasetIORouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "DatasetIORouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "DatasetIORouter.shutdown")
|
||||
pass
|
||||
|
||||
async def get_rows_paginated(
|
||||
|
|
@ -305,6 +477,10 @@ class DatasetIORouter(DatasetIO):
|
|||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=rows_in_page,
|
||||
|
|
@ -313,6 +489,7 @@ class DatasetIORouter(DatasetIO):
|
|||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
|
||||
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
|
||||
dataset_id=dataset_id,
|
||||
rows=rows,
|
||||
|
|
@ -324,12 +501,15 @@ class ScoringRouter(Scoring):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing ScoringRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "ScoringRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "ScoringRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def score_batch(
|
||||
|
|
@ -338,6 +518,7 @@ class ScoringRouter(Scoring):
|
|||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}")
|
||||
res = {}
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
|
||||
|
|
@ -358,6 +539,10 @@ class ScoringRouter(Scoring):
|
|||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
) -> ScoreResponse:
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions",
|
||||
)
|
||||
res = {}
|
||||
# look up and map each scoring function to its provider impl
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
|
|
@ -375,22 +560,26 @@ class EvalRouter(Eval):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing EvalRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "EvalRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "EvalRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
task_config: BenchmarkConfig,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
|
||||
benchmark_id=benchmark_id,
|
||||
task_config=task_config,
|
||||
benchmark_config=benchmark_config,
|
||||
)
|
||||
|
||||
async def evaluate_rows(
|
||||
|
|
@ -398,13 +587,14 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
task_config: BenchmarkConfig,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=input_rows,
|
||||
scoring_functions=scoring_functions,
|
||||
task_config=task_config,
|
||||
benchmark_config=benchmark_config,
|
||||
)
|
||||
|
||||
async def job_status(
|
||||
|
|
@ -412,6 +602,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> Optional[JobStatus]:
|
||||
logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
|
||||
|
||||
async def job_cancel(
|
||||
|
|
@ -419,6 +610,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> None:
|
||||
logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
|
||||
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
|
||||
benchmark_id,
|
||||
job_id,
|
||||
|
|
@ -429,6 +621,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> EvaluateResponse:
|
||||
logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
|
||||
benchmark_id,
|
||||
job_id,
|
||||
|
|
@ -441,6 +634,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def query(
|
||||
|
|
@ -449,6 +643,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
vector_db_ids: List[str],
|
||||
query_config: Optional[RAGQueryConfig] = None,
|
||||
) -> RAGQueryResult:
|
||||
logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
||||
return await self.routing_table.get_provider_impl("knowledge_search").query(
|
||||
content, vector_db_ids, query_config
|
||||
)
|
||||
|
|
@ -460,6 +655,10 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
chunk_size_in_tokens: int = 512,
|
||||
preprocessor_chain: Optional[PreprocessorChain] = None,
|
||||
) -> None:
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
||||
documents, vector_db_id, chunk_size_in_tokens, preprocessor_chain
|
||||
)
|
||||
|
|
@ -468,6 +667,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing ToolRuntimeRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
|
||||
|
|
@ -476,12 +676,15 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "ToolRuntimeRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "ToolRuntimeRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
|
||||
logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}")
|
||||
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
|
||||
tool_name=tool_name,
|
||||
kwargs=kwargs,
|
||||
|
|
@ -490,6 +693,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -318,13 +318,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
if provider_vector_db_id is None:
|
||||
provider_vector_db_id = vector_db_id
|
||||
if provider_id is None:
|
||||
# If provider_id not specified, use the only provider if it supports this shield type
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
if len(self.impls_by_provider_id) > 0:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
if len(self.impls_by_provider_id) > 1:
|
||||
logger.warning(
|
||||
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
raise ValueError("No provider available. Please configure a vector_io provider.")
|
||||
model = await self.get_object_by_identifier("model", embedding_model)
|
||||
if model is None:
|
||||
raise ValueError(f"Model {embedding_model} not found")
|
||||
|
|
@ -375,7 +376,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
|
||||
)
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
|
|
|||
|
|
@ -26,9 +26,9 @@ from fastapi import Path as FastapiPath
|
|||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
|
|
@ -55,7 +55,7 @@ from .endpoints import get_all_api_endpoints
|
|||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
logcat.init()
|
||||
|
||||
|
||||
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
|
||||
|
|
@ -142,23 +142,23 @@ def handle_signal(app, signum, _) -> None:
|
|||
not block the current execution.
|
||||
"""
|
||||
signame = signal.Signals(signum).name
|
||||
logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
|
||||
logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...")
|
||||
|
||||
async def shutdown():
|
||||
try:
|
||||
# Gracefully shut down implementations
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info("Shutting down %s", impl_name)
|
||||
logcat.info("server", f"Shutting down {impl_name}")
|
||||
try:
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logger.warning("No shutdown method for %s", impl_name)
|
||||
logcat.warning("server", f"No shutdown method for {impl_name}")
|
||||
except asyncio.TimeoutError:
|
||||
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
|
||||
logcat.exception("server", f"Shutdown timeout for {impl_name}")
|
||||
except Exception as e:
|
||||
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
|
||||
logcat.exception("server", f"Failed to shutdown {impl_name}: {e}")
|
||||
|
||||
# Gather all running tasks
|
||||
loop = asyncio.get_running_loop()
|
||||
|
|
@ -172,7 +172,7 @@ def handle_signal(app, signum, _) -> None:
|
|||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
logger.exception("Timeout while waiting for tasks to finish")
|
||||
logcat.exception("server", "Timeout while waiting for tasks to finish")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
|
|
@ -184,9 +184,9 @@ def handle_signal(app, signum, _) -> None:
|
|||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logger.info("Starting up")
|
||||
logcat.info("server", "Starting up")
|
||||
yield
|
||||
logger.info("Shutting down")
|
||||
logcat.info("server", "Shutting down")
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
await impl.shutdown()
|
||||
|
||||
|
|
@ -209,10 +209,11 @@ async def sse_generator(event_gen):
|
|||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
print("Generator cancelled")
|
||||
logcat.info("server", "Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
logcat.exception("server", f"Error in sse_generator: {e}")
|
||||
logcat.exception("server", f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
|
|
@ -234,7 +235,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
|||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
logcat.exception("server", f"Error in {func.__name__}")
|
||||
raise translate_exception(e) from e
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
|
@ -313,6 +314,8 @@ class ClientVersionMiddleware:
|
|||
|
||||
|
||||
def main():
|
||||
logcat.init()
|
||||
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
parser.add_argument(
|
||||
|
|
@ -352,10 +355,10 @@ def main():
|
|||
for env_pair in args.env:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting CLI environment variable {key} => {value}")
|
||||
logcat.info("server", f"Setting CLI environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
logcat.error("server", f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if args.yaml_config:
|
||||
|
|
@ -363,12 +366,12 @@ def main():
|
|||
config_file = Path(args.yaml_config)
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Config file {config_file} does not exist")
|
||||
logger.info(f"Using config file: {config_file}")
|
||||
logcat.info("server", f"Using config file: {config_file}")
|
||||
elif args.template:
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Template {args.template} does not exist")
|
||||
logger.info(f"Using template {args.template} config file: {config_file}")
|
||||
logcat.info("server", f"Using template {args.template} config file: {config_file}")
|
||||
else:
|
||||
raise ValueError("Either --yaml-config or --template must be provided")
|
||||
|
||||
|
|
@ -376,9 +379,10 @@ def main():
|
|||
config = replace_env_vars(yaml.safe_load(fp))
|
||||
config = StackRunConfig(**config)
|
||||
|
||||
logger.info("Run configuration:")
|
||||
logcat.info("server", "Run configuration:")
|
||||
safe_config = redact_sensitive_fields(config.model_dump())
|
||||
logger.info(yaml.dump(safe_config, indent=2))
|
||||
for log_line in yaml.dump(safe_config, indent=2).split("\n"):
|
||||
logcat.info("server", log_line)
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(TracingMiddleware)
|
||||
|
|
@ -388,7 +392,7 @@ def main():
|
|||
try:
|
||||
impls = asyncio.run(construct_stack(config))
|
||||
except InvalidProviderError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
logcat.error("server", f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if Api.telemetry in impls:
|
||||
|
|
@ -433,11 +437,8 @@ def main():
|
|||
)
|
||||
)
|
||||
|
||||
logger.info(f"Serving API {api_str}")
|
||||
for endpoint in endpoints:
|
||||
cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")
|
||||
logcat.debug("server", f"serving APIs: {apis_to_serve}")
|
||||
|
||||
print("")
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
|
||||
|
|
@ -463,10 +464,10 @@ def main():
|
|||
"ssl_keyfile": keyfile,
|
||||
"ssl_certfile": certfile,
|
||||
}
|
||||
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
logcat.info("server", f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
|
||||
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
|
||||
logger.info(f"Listening on {listen_host}:{port}")
|
||||
logcat.info("server", f"Listening on {listen_host}:{port}")
|
||||
|
||||
uvicorn_config = {
|
||||
"app": app,
|
||||
|
|
|
|||
|
|
@ -5,14 +5,15 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import importlib.resources
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import yaml
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.batch_inference import BatchInference
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
|
|
@ -35,14 +36,13 @@ from llama_stack.apis.telemetry import Telemetry
|
|||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDBs
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LlamaStack(
|
||||
VectorDBs,
|
||||
|
|
@ -106,12 +106,11 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
|
|||
objects_to_process = response.data if hasattr(response, "data") else response
|
||||
|
||||
for obj in objects_to_process:
|
||||
log.info(
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
|
||||
)
|
||||
|
||||
log.info("")
|
||||
|
||||
|
||||
class EnvVarError(Exception):
|
||||
def __init__(self, var_name: str, path: str = ""):
|
||||
|
|
@ -160,18 +159,34 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
|||
return result
|
||||
|
||||
elif isinstance(config, str):
|
||||
pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}"
|
||||
# Updated pattern to support both default values (:) and conditional values (+)
|
||||
pattern = r"\${env\.([A-Z0-9_]+)(?:([:\+])([^}]*))?}"
|
||||
|
||||
def get_env_var(match):
|
||||
env_var = match.group(1)
|
||||
default_val = match.group(2)
|
||||
operator = match.group(2) # ':' for default, '+' for conditional
|
||||
value_expr = match.group(3)
|
||||
|
||||
value = os.environ.get(env_var)
if not value:
if default_val is None:
raise EnvVarError(env_var, path)
env_value = os.environ.get(env_var)

if operator == ":": # Default value syntax: ${env.FOO:default}
if not env_value:
if value_expr is None:
raise EnvVarError(env_var, path)
else:
value = value_expr
else:
value = default_val
value = env_value
elif operator == "+": # Conditional value syntax: ${env.FOO+value_if_set}
if env_value:
value = value_expr
else:
# If env var is not set, return empty string for the conditional case
value = ""
else: # No operator case: ${env.FOO}
if not env_value:
raise EnvVarError(env_var, path)
value = env_value

# expand "~" from the values
return os.path.expanduser(value)
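The updated pattern recognizes three forms: ${env.FOO} (required), ${env.FOO:default} (fall back when unset or empty), and ${env.FOO+value} (substitute value only when the variable is set). A simplified re-implementation of those rules, shown purely for illustration and not identical to the llama-stack helper:

import os
import re

_PATTERN = r"\$\{env\.([A-Z0-9_]+)(?:([:\+])([^}]*))?\}"


def expand(text: str) -> str:
    def repl(match: re.Match) -> str:
        var, op, expr = match.group(1), match.group(2), match.group(3)
        val = os.environ.get(var)
        if op == ":":          # ${env.FOO:default} -> default when FOO is unset or empty
            return val or expr
        if op == "+":          # ${env.FOO+x} -> x when FOO is set, "" otherwise
            return expr if val else ""
        if not val:            # ${env.FOO} -> error when FOO is missing
            raise KeyError(var)
        return val

    return re.sub(_PATTERN, repl, text)


os.environ["PORT"] = "8321"
print(expand("port: ${env.PORT:5001}"))        # -> "port: 8321"
print(expand("tls: ${env.TLS_CERT+enabled}"))  # -> "tls: " when TLS_CERT is unset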
|
||||
|
|
@ -220,3 +235,53 @@ def get_stack_run_config_from_template(template: str) -> StackRunConfig:
|
|||
run_config = yaml.safe_load(path.open())
|
||||
|
||||
return StackRunConfig(**replace_env_vars(run_config))
|
||||
|
||||
|
||||
def run_config_from_adhoc_config_spec(
|
||||
adhoc_config_spec: str, provider_registry: Optional[ProviderRegistry] = None
|
||||
) -> StackRunConfig:
|
||||
"""
|
||||
Create an adhoc distribution from a list of API providers.
|
||||
|
||||
The list should be of the form "api=provider", e.g. "inference=fireworks". If you have
|
||||
multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference"
|
||||
"""
|
||||
|
||||
api_providers = adhoc_config_spec.replace(";", ",").split(",")
|
||||
provider_registry = provider_registry or get_provider_registry()
|
||||
|
||||
distro_dir = tempfile.mkdtemp()
|
||||
provider_configs_by_api = {}
|
||||
for api_provider in api_providers:
|
||||
api_str, provider = api_provider.split("=")
|
||||
api = Api(api_str)
|
||||
|
||||
providers_by_type = provider_registry[api]
|
||||
provider_spec = providers_by_type.get(provider)
|
||||
if not provider_spec:
|
||||
provider_spec = providers_by_type.get(f"inline::{provider}")
|
||||
if not provider_spec:
|
||||
provider_spec = providers_by_type.get(f"remote::{provider}")
|
||||
|
||||
if not provider_spec:
|
||||
raise ValueError(
|
||||
f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}"
|
||||
)
|
||||
|
||||
# call method "sample_run_config" on the provider spec config class
|
||||
provider_config_type = instantiate_class_type(provider_spec.config_class)
|
||||
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
|
||||
|
||||
provider_configs_by_api[api_str] = [
Provider(
provider_id=provider,
provider_type=provider_spec.provider_type,
config=provider_config,
)
]
config = StackRunConfig(
image_name="distro-test",
apis=list(provider_configs_by_api.keys()),
providers=provider_configs_by_api,
)
return config
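A hypothetical usage of the helper above; the import path and the provider names are assumptions for illustration only:

# Import path is an assumption based on the module shown in this diff; the
# providers named must actually be available in the local provider registry.
from llama_stack.distribution.stack import run_config_from_adhoc_config_spec

run_config = run_config_from_adhoc_config_spec("inference=fireworks,safety=llama-guard")
print(run_config.apis)        # e.g. ['inference', 'safety']
print(run_config.image_name)  # 'distro-test', per the function above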
|
||||
|
|
|
|||
|
|
@ -98,15 +98,20 @@ case "$env_type" in
|
|||
*)
|
||||
esac
|
||||
|
||||
set -x
|
||||
|
||||
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
|
||||
set -x
|
||||
$PYTHON_BINARY -m llama_stack.distribution.server.server \
|
||||
--yaml-config "$yaml_config" \
|
||||
--port "$port" \
|
||||
$env_vars \
|
||||
$other_args
|
||||
elif [[ "$env_type" == "container" ]]; then
|
||||
# Check if container command is available
|
||||
if ! is_command_available $CONTAINER_BINARY; then
|
||||
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if is_command_available selinuxenabled &> /dev/null && selinuxenabled; then
|
||||
# Disable SELinux labels
|
||||
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
|
||||
|
|
@ -136,6 +141,8 @@ elif [[ "$env_type" == "container" ]]; then
|
|||
version_tag=$(curl -s $URL | jq -r '.info.version')
|
||||
fi
|
||||
|
||||
set -x
|
||||
|
||||
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
|
||||
-p $port:$port \
|
||||
$env_vars \
|
||||
|
|
|
|||
|
|
@ -1,72 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
error_handler() {
|
||||
echo "Error occurred in script at line: ${1}" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
trap 'error_handler ${LINENO}' ERR
|
||||
|
||||
if [ $# -lt 3 ]; then
|
||||
echo "Usage: $0 <venv_path> <yaml_config> <port> <script_args...>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
venv_path="$1"
|
||||
shift
|
||||
|
||||
yaml_config="$1"
|
||||
shift
|
||||
|
||||
port="$1"
|
||||
shift
|
||||
|
||||
# Initialize env_vars as an empty array
|
||||
env_vars=""
|
||||
other_args=""
|
||||
# Process environment variables from --env arguments
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--env)
|
||||
|
||||
if [[ -n "$2" ]]; then
|
||||
env_vars="$env_vars --env $2"
|
||||
shift 2
|
||||
else
|
||||
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
other_args="$other_args $1"
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
echo "Using virtual environment: $venv_path"
|
||||
# Activate virtual environment
|
||||
if [ ! -d "$venv_path" ]; then
|
||||
echo -e "${RED}Error: Virtual environment not found at $venv_path${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
source "$venv_path/bin/activate"
|
||||
|
||||
set -x
|
||||
python -m llama_stack.distribution.server.server \
|
||||
--yaml-config "$yaml_config" \
|
||||
--port "$port" \
|
||||
$env_vars \
|
||||
$other_args
|
||||
|
|
@ -1,199 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.inference import Model
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.distribution.store.registry import (
|
||||
CachedDiskDistributionRegistry,
|
||||
DiskDistributionRegistry,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config():
|
||||
config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db")
|
||||
if os.path.exists(config.db_path):
|
||||
os.remove(config.db_path)
|
||||
return config
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def registry(config):
|
||||
registry = DiskDistributionRegistry(await kvstore_impl(config))
|
||||
await registry.initialize()
|
||||
return registry
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def cached_registry(config):
|
||||
registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||
await registry.initialize()
|
||||
return registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_vector_db():
|
||||
return VectorDB(
|
||||
identifier="test_vector_db",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_model():
|
||||
return Model(
|
||||
identifier="test_model",
|
||||
provider_resource_id="test_model",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registry_initialization(registry):
|
||||
# Test empty registry
|
||||
result = await registry.get("nonexistent", "nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_registration(registry, sample_vector_db, sample_model):
|
||||
print(f"Registering {sample_vector_db}")
|
||||
await registry.register(sample_vector_db)
|
||||
print(f"Registering {sample_model}")
|
||||
await registry.register(sample_model)
|
||||
print("Getting vector_db")
|
||||
result_vector_db = await registry.get("vector_db", "test_vector_db")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == sample_vector_db.identifier
|
||||
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
|
||||
assert result_vector_db.provider_id == sample_vector_db.provider_id
|
||||
|
||||
result_model = await registry.get("model", "test_model")
|
||||
assert result_model is not None
|
||||
assert result_model.identifier == sample_model.identifier
|
||||
assert result_model.provider_id == sample_model.provider_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cached_registry_initialization(config, sample_vector_db, sample_model):
|
||||
# First populate the disk registry
|
||||
disk_registry = DiskDistributionRegistry(await kvstore_impl(config))
|
||||
await disk_registry.initialize()
|
||||
await disk_registry.register(sample_vector_db)
|
||||
await disk_registry.register(sample_model)
|
||||
|
||||
# Test cached version loads from disk
|
||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||
await cached_registry.initialize()
|
||||
|
||||
result_vector_db = await cached_registry.get("vector_db", "test_vector_db")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == sample_vector_db.identifier
|
||||
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
|
||||
assert result_vector_db.embedding_dimension == sample_vector_db.embedding_dimension
|
||||
assert result_vector_db.provider_id == sample_vector_db.provider_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cached_registry_updates(config):
|
||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||
await cached_registry.initialize()
|
||||
|
||||
new_vector_db = VectorDB(
|
||||
identifier="test_vector_db_2",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db_2",
|
||||
provider_id="baz",
|
||||
)
|
||||
await cached_registry.register(new_vector_db)
|
||||
|
||||
# Verify in cache
|
||||
result_vector_db = await cached_registry.get("vector_db", "test_vector_db_2")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == new_vector_db.identifier
|
||||
assert result_vector_db.provider_id == new_vector_db.provider_id
|
||||
|
||||
# Verify persisted to disk
|
||||
new_registry = DiskDistributionRegistry(await kvstore_impl(config))
|
||||
await new_registry.initialize()
|
||||
result_vector_db = await new_registry.get("vector_db", "test_vector_db_2")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == new_vector_db.identifier
|
||||
assert result_vector_db.provider_id == new_vector_db.provider_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_duplicate_provider_registration(config):
|
||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||
await cached_registry.initialize()
|
||||
|
||||
original_vector_db = VectorDB(
|
||||
identifier="test_vector_db_2",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db_2",
|
||||
provider_id="baz",
|
||||
)
|
||||
await cached_registry.register(original_vector_db)
|
||||
|
||||
duplicate_vector_db = VectorDB(
|
||||
identifier="test_vector_db_2",
|
||||
embedding_model="different-model",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id="test_vector_db_2",
|
||||
provider_id="baz", # Same provider_id
|
||||
)
|
||||
await cached_registry.register(duplicate_vector_db)
|
||||
|
||||
result = await cached_registry.get("vector_db", "test_vector_db_2")
|
||||
assert result is not None
|
||||
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_objects(config):
|
||||
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||
await cached_registry.initialize()
|
||||
|
||||
# Create multiple test banks
|
||||
test_vector_dbs = [
|
||||
VectorDB(
|
||||
identifier=f"test_vector_db_{i}",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
provider_resource_id=f"test_vector_db_{i}",
|
||||
provider_id=f"provider_{i}",
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
# Register all vector_dbs
|
||||
for vector_db in test_vector_dbs:
|
||||
await cached_registry.register(vector_db)
|
||||
|
||||
# Test get_all retrieval
|
||||
all_results = await cached_registry.get_all()
|
||||
assert len(all_results) == 3
|
||||
|
||||
# Verify each vector_db was stored correctly
|
||||
for original_vector_db in test_vector_dbs:
|
||||
matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier]
|
||||
assert len(matching_vector_dbs) == 1
|
||||
stored_vector_db = matching_vector_dbs[0]
|
||||
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
|
||||
assert stored_vector_db.provider_id == original_vector_db.provider_id
|
||||
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
|
||||
|
|
@ -17,7 +17,7 @@ llama stack run together
|
|||
2. (Optional) Register datasets and eval tasks as resources. If you want to run pre-configured evaluation flows (e.g. Evaluations (Generation + Scoring) Page).
|
||||
|
||||
```bash
|
||||
$ llama-stack-client datasets register \
|
||||
llama-stack-client datasets register \
|
||||
--dataset-id "mmlu" \
|
||||
--provider-id "huggingface" \
|
||||
--url "https://huggingface.co/datasets/llamastack/evals" \
|
||||
|
|
@@ -26,7 +26,7 @@ $ llama-stack-client datasets register \
```

```bash
$ llama-stack-client benchmarks register \
llama-stack-client benchmarks register \
--eval-task-id meta-reference-mmlu \
--provider-id meta-reference \
--dataset-id mmlu \
@@ -212,7 +212,7 @@ def run_evaluation_3():
            benchmark_id=selected_benchmark,
            input_rows=[r],
            scoring_functions=benchmarks[selected_benchmark].scoring_functions,
            task_config=benchmark_config,
            benchmark_config=benchmark_config,
        )

        for k in r.keys():
@@ -7,7 +7,6 @@
import streamlit as st
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.memory_insert_params import Document
from modules.api import llama_stack_api
from modules.utils import data_url_from_file
@@ -124,26 +123,22 @@ def rag_chat_page():
    else:
        strategy = {"type": "greedy"}

    agent_config = AgentConfig(
    agent = Agent(
        llama_stack_api.client,
        model=selected_model,
        instructions=system_prompt,
        sampling_params={
            "strategy": strategy,
        },
        toolgroups=[
        tools=[
            dict(
                name="builtin::rag/knowledge_search",
                args={
                    "vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
                    "vector_db_ids": list(selected_vector_dbs),
                },
            )
        ],
        tool_choice="auto",
        tool_prompt_format="json",
        enable_session_persistence=False,
    )

    agent = Agent(llama_stack_api.client, agent_config)
    session_id = agent.create_session("rag-session")

    # Chat input
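To make the interleaved old/new lines in this hunk easier to follow, here is a consolidated sketch of the new-style construction it introduces. It only restates what the diff shows and assumes `llama_stack_api`, `selected_model`, `system_prompt`, `strategy`, and `selected_vector_dbs` are defined earlier in `rag_chat_page()`, as in the surrounding page code.

```python
# Sketch only: consolidated new-style Agent construction from the hunk above.
from llama_stack_client.lib.agents.agent import Agent

agent = Agent(
    llama_stack_api.client,  # assumed: the client wrapper imported from modules.api
    model=selected_model,
    instructions=system_prompt,
    sampling_params={"strategy": strategy},
    tools=[
        dict(
            name="builtin::rag/knowledge_search",
            args={"vector_db_ids": list(selected_vector_dbs)},
        )
    ],
)
session_id = agent.create_session("rag-session")
```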
@@ -13,6 +13,4 @@ DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"

DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"

BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds"

RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"
204
llama_stack/logcat.py
Normal file
|
|
@ -0,0 +1,204 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Category-based logging utility for llama-stack.
|
||||
|
||||
This module provides a wrapper over the standard Python logging module that supports
|
||||
categorized logging with environment variable control.
|
||||
|
||||
Usage:
|
||||
from llama_stack import logcat
|
||||
logcat.info("server", "Starting up...")
|
||||
logcat.debug("inference", "Processing request...")
|
||||
|
||||
Environment variable:
|
||||
LLAMA_STACK_LOGGING: Semicolon-separated list of category=level pairs
|
||||
Example: "server=debug;inference=warning"
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
# ANSI color codes for terminal output
|
||||
COLORS = {
|
||||
"RESET": "\033[0m",
|
||||
"DEBUG": "\033[36m", # Cyan
|
||||
"INFO": "\033[32m", # Green
|
||||
"WARNING": "\033[33m", # Yellow
|
||||
"ERROR": "\033[31m", # Red
|
||||
"CRITICAL": "\033[35m", # Magenta
|
||||
"DIM": "\033[2m", # Dimmed text
|
||||
"YELLOW_DIM": "\033[2;33m", # Dimmed yellow
|
||||
}
|
||||
|
||||
# Static list of valid categories representing various parts of the Llama Stack
|
||||
# server codebase
|
||||
CATEGORIES = [
|
||||
"core",
|
||||
"server",
|
||||
"router",
|
||||
"inference",
|
||||
"agents",
|
||||
"safety",
|
||||
"eval",
|
||||
"tools",
|
||||
"client",
|
||||
]
|
||||
|
||||
_logger = logging.getLogger("llama_stack")
|
||||
_logger.propagate = False
|
||||
|
||||
_default_level = logging.INFO
|
||||
|
||||
# Category-level mapping (can be modified by environment variables)
|
||||
_category_levels: Dict[str, int] = {}
|
||||
|
||||
|
||||
class TerminalStreamHandler(logging.StreamHandler):
|
||||
def __init__(self, stream=None):
|
||||
super().__init__(stream)
|
||||
self.is_tty = hasattr(self.stream, "isatty") and self.stream.isatty()
|
||||
|
||||
def format(self, record):
|
||||
record.is_tty = self.is_tty
|
||||
return super().format(record)
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
"""Custom formatter with colors and fixed-width level names"""
|
||||
|
||||
def format(self, record):
|
||||
levelname = record.levelname
|
||||
# Use only time with milliseconds, not date
|
||||
timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3] # HH:MM:SS.mmm format
|
||||
|
||||
file_info = f"{record.filename}:{record.lineno}"
|
||||
|
||||
# Get category from extra if available
|
||||
category = getattr(record, "category", None)
|
||||
msg = record.getMessage()
|
||||
|
||||
if getattr(record, "is_tty", False):
|
||||
color = COLORS.get(levelname, COLORS["RESET"])
|
||||
if category:
|
||||
category_formatted = f"{COLORS['YELLOW_DIM']}{category}{COLORS['RESET']} "
|
||||
formatted_msg = (
|
||||
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} "
|
||||
f"{file_info:<20} {category_formatted}{msg}"
|
||||
)
|
||||
else:
|
||||
formatted_msg = (
|
||||
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']}] "
|
||||
f"{file_info:<20} {msg}"
|
||||
)
|
||||
else:
|
||||
if category:
|
||||
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} [{category}] {msg}"
|
||||
else:
|
||||
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} {msg}"
|
||||
|
||||
return formatted_msg
|
||||
|
||||
|
||||
def init(default_level: int = logging.INFO) -> None:
|
||||
global _default_level, _category_levels, _logger
|
||||
|
||||
_default_level = default_level
|
||||
|
||||
_logger.setLevel(logging.DEBUG)
|
||||
_logger.handlers = [] # Clear existing handlers
|
||||
|
||||
# Add our custom handler with the colored formatter
|
||||
handler = TerminalStreamHandler()
|
||||
formatter = ColoredFormatter()
|
||||
handler.setFormatter(formatter)
|
||||
_logger.addHandler(handler)
|
||||
|
||||
for category in CATEGORIES:
|
||||
_category_levels[category] = default_level
|
||||
|
||||
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
|
||||
if env_config:
|
||||
for pair in env_config.split(";"):
|
||||
if not pair.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
category, level = pair.split("=", 1)
|
||||
category = category.strip().lower()
|
||||
level = level.strip().lower()
|
||||
|
||||
level_value = {
|
||||
"debug": logging.DEBUG,
|
||||
"info": logging.INFO,
|
||||
"warning": logging.WARNING,
|
||||
"warn": logging.WARNING,
|
||||
"error": logging.ERROR,
|
||||
"critical": logging.CRITICAL,
|
||||
}.get(level)
|
||||
|
||||
if level_value is None:
|
||||
_logger.warning(f"Unknown log level '{level}' for category '{category}'")
|
||||
continue
|
||||
|
||||
if category == "all":
|
||||
for cat in CATEGORIES:
|
||||
_category_levels[cat] = level_value
|
||||
else:
|
||||
if category in CATEGORIES:
|
||||
_category_levels[category] = level_value
|
||||
else:
|
||||
_logger.warning(f"Unknown logging category: {category}")
|
||||
|
||||
except ValueError:
|
||||
_logger.warning(f"Invalid logging configuration: {pair}")
|
||||
|
||||
|
||||
def _should_log(level: int, category: str) -> bool:
|
||||
category = category.lower()
|
||||
if category not in _category_levels:
|
||||
return False
|
||||
category_level = _category_levels[category]
|
||||
return level >= category_level
|
||||
|
||||
|
||||
def _log(level: int, level_name: str, category: str, msg: str, *args, **kwargs) -> None:
|
||||
if _should_log(level, category):
|
||||
kwargs.setdefault("extra", {})["category"] = category.lower()
|
||||
getattr(_logger, level_name)(msg, *args, stacklevel=3, **kwargs)
|
||||
|
||||
|
||||
def debug(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.DEBUG, "debug", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def info(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.INFO, "info", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def warning(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.WARNING, "warning", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def warn(category: str, msg: str, *args, **kwargs) -> None:
|
||||
warning(category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def error(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.ERROR, "error", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def critical(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.CRITICAL, "critical", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def exception(category: str, msg: str, *args, **kwargs) -> None:
|
||||
if _should_log(logging.ERROR, category):
|
||||
kwargs.setdefault("extra", {})["category"] = category.lower()
|
||||
_logger.exception(msg, *args, stacklevel=2, **kwargs)
|
||||
|
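A minimal usage sketch for the logging module above, assuming `llama_stack` is importable and that `LLAMA_STACK_LOGGING` is set before `init()` runs:

```python
import logging
import os

# Raise the "server" category to debug and keep "inference" at warning.
os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning"

from llama_stack import logcat

logcat.init(default_level=logging.INFO)
logcat.debug("server", "Starting up...")          # emitted: server is at debug
logcat.debug("inference", "Processing request")   # suppressed: inference is at warning
logcat.warning("inference", "Slow request")       # emitted: warning >= warning
```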
|
@ -11,16 +11,128 @@
|
|||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import base64
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
# import all for backwards compatibility
|
||||
from llama_models.datatypes import * # noqa: F403
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||
|
||||
# The goal is that these set of types are relevant for all Llama models.
|
||||
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
|
||||
# the llama3 series of models.
|
||||
|
||||
|
||||
class Role(Enum):
|
||||
system = "system"
|
||||
user = "user"
|
||||
assistant = "assistant"
|
||||
tool = "tool"
|
||||
|
||||
|
||||
class BuiltinTool(Enum):
|
||||
brave_search = "brave_search"
|
||||
wolfram_alpha = "wolfram_alpha"
|
||||
photogen = "photogen"
|
||||
code_interpreter = "code_interpreter"
|
||||
|
||||
|
||||
Primitive = Union[str, int, float, bool, None]
|
||||
RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
call_id: str
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
arguments: Dict[str, RecursiveType]
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
def validate_field(cls, v):
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return BuiltinTool(v)
|
||||
except ValueError:
|
||||
return v
|
||||
return v
|
||||
|
||||
|
||||
class ToolPromptFormat(Enum):
|
||||
"""Prompt format for calling custom / zero shot tools.
|
||||
|
||||
:cvar json: JSON format for calling tools. It takes the form:
|
||||
{
|
||||
"type": "function",
|
||||
"function" : {
|
||||
"name": "function_name",
|
||||
"description": "function_description",
|
||||
"parameters": {...}
|
||||
}
|
||||
}
|
||||
:cvar function_tag: Function tag format, pseudo-XML. This looks like:
|
||||
<function=function_name>(parameters)</function>
|
||||
|
||||
:cvar python_list: Python list. The output is a valid Python expression that can be
|
||||
evaluated to a list. Each element in the list is a function call. Example:
|
||||
["function_name(param1, param2)", "function_name(param1, param2)"]
|
||||
"""
|
||||
|
||||
json = "json"
|
||||
function_tag = "function_tag"
|
||||
python_list = "python_list"
|
||||
|
||||
|
||||
class StopReason(Enum):
|
||||
end_of_turn = "end_of_turn"
|
||||
end_of_message = "end_of_message"
|
||||
out_of_tokens = "out_of_tokens"
|
||||
|
||||
|
||||
class RawMediaItem(BaseModel):
|
||||
type: Literal["image"] = "image"
|
||||
data: bytes | BytesIO
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@field_serializer("data")
|
||||
def serialize_data(self, data: Optional[bytes], _info):
|
||||
if data is None:
|
||||
return None
|
||||
return base64.b64encode(data).decode("utf-8")
|
||||
|
||||
@field_validator("data", mode="before")
|
||||
@classmethod
|
||||
def validate_data(cls, v):
|
||||
if isinstance(v, str):
|
||||
return base64.b64decode(v)
|
||||
return v
|
||||
|
||||
|
||||
class RawTextItem(BaseModel):
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
|
||||
|
||||
RawContentItem = Annotated[Union[RawTextItem, RawMediaItem], Field(discriminator="type")]
|
||||
|
||||
RawContent = str | RawContentItem | List[RawContentItem]
|
||||
|
||||
|
||||
class RawMessage(BaseModel):
|
||||
role: Literal["user"] | Literal["system"] | Literal["tool"] | Literal["assistant"]
|
||||
content: RawContent
|
||||
|
||||
# This is for RAG but likely should be absorbed into content
|
||||
context: Optional[RawContent] = None
|
||||
|
||||
# These are for the output message coming from the assistant
|
||||
stop_reason: Optional[StopReason] = None
|
||||
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema(ToolCall)
|
||||
|
||||
|
||||
|
|
|
|||
282
llama_stack/models/llama/llama3/chat_format.py
Normal file
|
|
@ -0,0 +1,282 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import io
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from PIL import Image as PIL_Image
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
RawContent,
|
||||
RawMediaItem,
|
||||
RawMessage,
|
||||
RawTextItem,
|
||||
Role,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
from .tokenizer import Tokenizer
|
||||
from .tool_utils import ToolUtils
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionInput:
|
||||
mask: List[List[int]]
|
||||
images: List[PIL_Image.Image]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMInput:
|
||||
tokens: List[int]
|
||||
vision: Optional[VisionInput] = None
|
||||
|
||||
|
||||
def role_str(role: Role) -> str:
|
||||
role_strs = {
|
||||
Role.user: "user",
|
||||
Role.system: "system",
|
||||
Role.tool: "ipython", # special
|
||||
Role.assistant: "assistant",
|
||||
}
|
||||
return role_strs[role]
|
||||
|
||||
|
||||
class ChatFormat:
|
||||
possible_headers: Dict[Role, str]
|
||||
|
||||
def __init__(self, tokenizer: Tokenizer):
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.possible_headers = {role: f"<|start_header_id|>{role_str(role)}<|end_header_id|>\n\n" for role in Role}
|
||||
self.vision_token = self.tokenizer.special_tokens["<|image|>"]
|
||||
|
||||
def _encode_header(self, role: str) -> List[int]:
|
||||
tokens = []
|
||||
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
|
||||
tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
|
||||
tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
|
||||
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
|
||||
return tokens
|
||||
|
||||
def encode_content(self, content: RawContent) -> LLMInput:
|
||||
tokens, images = self._encode_content(content, bos=True)
|
||||
return self._model_input_from_tokens_images(tokens, images)
|
||||
|
||||
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[PIL_Image.Image]]:
|
||||
tokens = []
|
||||
images = []
|
||||
|
||||
added_bos = False
|
||||
|
||||
def _process(c):
|
||||
nonlocal added_bos, bos
|
||||
|
||||
if isinstance(c, str) or isinstance(c, RawTextItem):
|
||||
if isinstance(c, RawTextItem):
|
||||
c = c.text
|
||||
tokens.extend(self.tokenizer.encode(c, bos=False if added_bos else bos, eos=False))
|
||||
added_bos = True
|
||||
|
||||
elif isinstance(c, RawMediaItem):
|
||||
bos = False if added_bos else bos
|
||||
if bos:
|
||||
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
|
||||
added_bos = True
|
||||
tokens.append(self.vision_token)
|
||||
|
||||
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
|
||||
image = PIL_Image.open(bytes_io)
|
||||
image = image.convert("RGB")
|
||||
images.append(image)
|
||||
|
||||
if isinstance(content, list):
|
||||
for c in content:
|
||||
_process(c)
|
||||
else:
|
||||
_process(content)
|
||||
|
||||
return tokens, images
|
||||
|
||||
def encode_message(
|
||||
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
|
||||
) -> Tuple[List[int], List[PIL_Image.Image]]:
|
||||
tokens = self._encode_header(message.role)
|
||||
images = []
|
||||
|
||||
def _process_content(c):
|
||||
toks, imgs = self._encode_content(c)
|
||||
tokens.extend(toks)
|
||||
images.extend(imgs)
|
||||
|
||||
if (
|
||||
message.role == "assistant"
|
||||
and len(message.tool_calls) > 0
|
||||
and message.tool_calls[0].tool_name == BuiltinTool.code_interpreter
|
||||
):
|
||||
tokens.append(self.tokenizer.special_tokens["<|python_tag|>"])
|
||||
|
||||
_process_content(message.content)
|
||||
|
||||
if message.role == "user" and message.context is not None:
|
||||
# This is RAG context; why is it here in the chat format? I don't think
|
||||
# this is needed and can be moved upwards
|
||||
_process_content("\n\n")
|
||||
_process_content(message.context)
|
||||
|
||||
if message.role == "assistant":
|
||||
for t in message.tool_calls:
|
||||
content = ToolUtils.encode_tool_call(t, tool_prompt_format)
|
||||
_process_content(content)
|
||||
|
||||
eom = False
|
||||
if message.role == "assistant":
|
||||
eom = message.stop_reason == StopReason.end_of_message
|
||||
|
||||
tokens.append(self.tokenizer.special_tokens["<|eom_id|>" if eom else "<|eot_id|>"])
|
||||
return tokens, images
|
||||
|
||||
def encode_dialog_prompt(
|
||||
self,
|
||||
messages: List[RawMessage],
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
) -> LLMInput:
|
||||
tool_prompt_format = tool_prompt_format or ToolPromptFormat.json
|
||||
tokens = []
|
||||
images = []
|
||||
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
|
||||
for message in messages:
|
||||
toks, imgs = self.encode_message(message, tool_prompt_format)
|
||||
tokens.extend(toks)
|
||||
images.extend(imgs)
|
||||
|
||||
# Add the start of an assistant message for the model to complete.
|
||||
tokens.extend(self._encode_header("assistant"))
|
||||
|
||||
return self._model_input_from_tokens_images(tokens, images)
|
||||
|
||||
# TODO(this should be generic, not only for assistant messages)
|
||||
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
|
||||
content = self.tokenizer.decode(tokens)
|
||||
|
||||
return self.decode_assistant_message_from_content(content, stop_reason)
|
||||
|
||||
def decode_assistant_message_from_content(self, content: str, stop_reason: StopReason) -> RawMessage:
|
||||
content = content.strip(" ")
|
||||
header_str = self.possible_headers[Role.assistant]
|
||||
if content.startswith(header_str):
|
||||
content = content[len(header_str) :]
|
||||
|
||||
ipython = content.startswith("<|python_tag|>")
|
||||
if ipython:
|
||||
content = content[len("<|python_tag|>") :]
|
||||
|
||||
if content.endswith("<|eot_id|>"):
|
||||
content = content[: -len("<|eot_id|>")]
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif content.endswith("<|eom_id|>"):
|
||||
content = content[: -len("<|eom_id|>")]
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
tool_name = None
|
||||
tool_arguments = {}
|
||||
|
||||
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
|
||||
if custom_tool_info is not None:
|
||||
tool_name, tool_arguments = custom_tool_info
|
||||
# Sometimes when the agent has custom tools alongside builtin tools
|
||||
# Agent responds for builtin tool calls in the format of the custom tools
|
||||
# This code tries to handle that case
|
||||
if tool_name in BuiltinTool.__members__:
|
||||
tool_name = BuiltinTool[tool_name]
|
||||
tool_arguments = {
|
||||
"query": list(tool_arguments.values())[0],
|
||||
}
|
||||
else:
|
||||
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
|
||||
if builtin_tool_info is not None:
|
||||
tool_name, query = builtin_tool_info
|
||||
tool_arguments = {
|
||||
"query": query,
|
||||
}
|
||||
if tool_name in BuiltinTool.__members__:
|
||||
tool_name = BuiltinTool[tool_name]
|
||||
elif ipython:
|
||||
tool_name = BuiltinTool.code_interpreter
|
||||
tool_arguments = {
|
||||
"code": content,
|
||||
}
|
||||
|
||||
tool_calls = []
|
||||
if tool_name is not None and tool_arguments is not None:
|
||||
call_id = str(uuid.uuid4())
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
)
|
||||
)
|
||||
content = ""
|
||||
|
||||
return RawMessage(
|
||||
role="assistant",
|
||||
content=content,
|
||||
stop_reason=stop_reason,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
def _model_input_from_tokens_images(self, tokens: List[int], images: List[PIL_Image.Image]) -> LLMInput:
|
||||
vision_input = None
|
||||
if len(images) > 0:
|
||||
vision_input = VisionInput(
|
||||
mask=create_vision_mask(tokens, self.vision_token),
|
||||
images=images,
|
||||
)
|
||||
|
||||
return LLMInput(
|
||||
tokens=[128256 if token == self.vision_token else token for token in tokens],
|
||||
vision=vision_input,
|
||||
)
|
||||
|
||||
|
||||
def create_vision_mask(
|
||||
tokens: List[int],
|
||||
vision_token: int,
|
||||
) -> List[List[int]]:
|
||||
vision_token_locations = [i for i, token in enumerate(tokens) if token == vision_token]
|
||||
if len(vision_token_locations) == 0:
|
||||
return []
|
||||
|
||||
if len(vision_token_locations) == 1:
|
||||
# only one image present, unmask until end of sequence
|
||||
return [[vision_token_locations[0], -1]]
|
||||
vision_masks = [
|
||||
[loc1, loc2] for loc1, loc2 in zip(vision_token_locations[:-1], vision_token_locations[1:], strict=False)
|
||||
]
|
||||
# last image will attend to all subsequent text
|
||||
vision_masks.append([vision_token_locations[-1], len(tokens)])
|
||||
|
||||
# if there are two or more consecutive vision tokens,
|
||||
# they should all attend to all subsequent
|
||||
# text present
|
||||
last_mask_end = vision_masks[-1][1]
|
||||
for vision_mask in vision_masks[::-1]:
|
||||
if vision_mask[0] == vision_mask[1] - 1:
|
||||
vision_mask[1] = last_mask_end
|
||||
last_mask_end = vision_mask[1]
|
||||
return vision_masks
|
||||
|
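A small illustration of `create_vision_mask` from the file above; the token ids are made up for the example (42 stands in for the vision token id):

```python
# Two images: each vision token attends from its position up to the next image,
# and the last one attends to the end of the sequence.
tokens = [101, 42, 7, 8, 42, 9, 10]
print(create_vision_mask(tokens, vision_token=42))
# [[1, 4], [4, 7]]
```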
|
@ -14,20 +14,19 @@
|
|||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_models.datatypes import (
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
RawMessage,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition
|
||||
|
||||
from . import template_data
|
||||
from .chat_format import ChatFormat
|
||||
from .prompt_templates import (
|
||||
BuiltinToolGenerator,
|
||||
FunctionTagCustomToolGenerator,
|
||||
|
|
@ -35,6 +34,7 @@ from .prompt_templates import (
|
|||
SystemDefaultGenerator,
|
||||
ToolResponseGenerator,
|
||||
)
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
THIS_DIR = Path(__file__).parent
|
||||
|
||||
|
|
|
|||
|
|
@ -15,11 +15,8 @@ import textwrap
|
|||
from datetime import datetime
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_models.datatypes import (
|
||||
BuiltinTool,
|
||||
)
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@
|
|||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from llama_models.datatypes import (
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
|
|
|
|||
|
|
@ -1,199 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import textwrap
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
|
||||
from .prompt_templates import (
|
||||
BuiltinToolGenerator,
|
||||
FunctionTagCustomToolGenerator,
|
||||
JsonCustomToolGenerator,
|
||||
PythonListCustomToolGenerator,
|
||||
SystemDefaultGenerator,
|
||||
)
|
||||
|
||||
|
||||
class PromptTemplateTests(unittest.TestCase):
|
||||
def check_generator_output(self, generator, expected_text):
|
||||
example = generator.data_examples()[0]
|
||||
|
||||
pt = generator.gen(example)
|
||||
text = pt.render()
|
||||
# print(text) # debugging
|
||||
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
|
||||
|
||||
def test_system_default(self):
|
||||
generator = SystemDefaultGenerator()
|
||||
today = datetime.now().strftime("%d %B %Y")
|
||||
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
|
||||
self.check_generator_output(generator, expected_text)
|
||||
|
||||
def test_system_builtin_only(self):
|
||||
generator = BuiltinToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Environment: ipython
|
||||
Tools: brave_search, wolfram_alpha
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
|
||||
def test_system_custom_only(self):
|
||||
self.maxDiff = None
|
||||
generator = JsonCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Answer the user's question by making use of the following functions if needed.
|
||||
If none of the function can be used, please say so.
|
||||
Here is a list of functions in JSON format:
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "trending_songs",
|
||||
"description": "Returns the trending songs on a Music site",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": [
|
||||
{
|
||||
"n": {
|
||||
"type": "object",
|
||||
"description": "The number of songs to return"
|
||||
}
|
||||
},
|
||||
{
|
||||
"genre": {
|
||||
"type": "object",
|
||||
"description": "The genre of the songs to return"
|
||||
}
|
||||
}
|
||||
],
|
||||
"required": ["n"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Return function calls in JSON format.
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
|
||||
def test_system_custom_function_tag(self):
|
||||
self.maxDiff = None
|
||||
generator = FunctionTagCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
You have access to the following functions:
|
||||
|
||||
Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
|
||||
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
|
||||
|
||||
Think very carefully before calling functions.
|
||||
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||
|
||||
<function=example_function_name>{"example_name": "example_value"}</function>
|
||||
|
||||
Reminder:
|
||||
- If looking for real time information use relevant functions before falling back to brave_search
|
||||
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||
- Required parameters MUST be specified
|
||||
- Only call one function at a time
|
||||
- Put the entire function call reply on one line
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
|
||||
def test_llama_3_2_system_zero_shot(self):
|
||||
generator = PythonListCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||
also point it out. You should only return the function call in tools call sections.
|
||||
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather info for places",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": ["city"],
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for"
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
"""
|
||||
)
|
||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
||||
|
||||
def test_llama_3_2_provided_system_prompt(self):
|
||||
generator = PythonListCustomToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Overriding message.
|
||||
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather info for places",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": ["city"],
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for"
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]"""
|
||||
)
|
||||
user_system_prompt = textwrap.dedent(
|
||||
"""
|
||||
Overriding message.
|
||||
|
||||
{{ function_description }}
|
||||
"""
|
||||
)
|
||||
example = generator.data_examples()[0]
|
||||
|
||||
pt = generator.gen(example, user_system_prompt)
|
||||
text = pt.render()
|
||||
assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
|
||||
128000
llama_stack/models/llama/llama3/tokenizer.model
Normal file
File diff suppressed because it is too large
214
llama_stack/models/llama/llama3/tokenizer.py
Normal file
|
|
@ -0,0 +1,214 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import os
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
AbstractSet,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import tiktoken
|
||||
from tiktoken.load import load_tiktoken_bpe
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
# The tiktoken tokenizer can handle <=400k chars without
|
||||
# pyo3_runtime.PanicException.
|
||||
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
|
||||
|
||||
# https://github.com/openai/tiktoken/issues/195
|
||||
# Here we iterate over subsequences and split if we exceed the limit
|
||||
# of max consecutive non-whitespace or whitespace characters.
|
||||
MAX_NO_WHITESPACES_CHARS = 25_000
|
||||
|
||||
|
||||
_INSTANCE = None
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""
|
||||
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
|
||||
"""
|
||||
|
||||
special_tokens: Dict[str, int]
|
||||
|
||||
num_reserved_special_tokens = 256
|
||||
|
||||
pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
global _INSTANCE
|
||||
|
||||
if _INSTANCE is None:
|
||||
_INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
|
||||
return _INSTANCE
|
||||
|
||||
def __init__(self, model_path: str):
|
||||
"""
|
||||
Initializes the Tokenizer with a Tiktoken model.
|
||||
|
||||
Args:
|
||||
model_path (str): The path to the Tiktoken model file.
|
||||
"""
|
||||
assert os.path.isfile(model_path), model_path
|
||||
|
||||
mergeable_ranks = load_tiktoken_bpe(model_path)
|
||||
num_base_tokens = len(mergeable_ranks)
|
||||
special_tokens = [
|
||||
"<|begin_of_text|>",
|
||||
"<|end_of_text|>",
|
||||
"<|reserved_special_token_0|>",
|
||||
"<|reserved_special_token_1|>",
|
||||
"<|finetune_right_pad_id|>",
|
||||
"<|step_id|>",
|
||||
"<|start_header_id|>",
|
||||
"<|end_header_id|>",
|
||||
"<|eom_id|>", # end of message
|
||||
"<|eot_id|>", # end of turn
|
||||
"<|python_tag|>",
|
||||
"<|image|>",
|
||||
]
|
||||
reserved_tokens = [
|
||||
f"<|reserved_special_token_{2 + i}|>" for i in range(self.num_reserved_special_tokens - len(special_tokens))
|
||||
]
|
||||
special_tokens = special_tokens + reserved_tokens
|
||||
|
||||
self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
|
||||
self.model = tiktoken.Encoding(
|
||||
name=Path(model_path).name,
|
||||
pat_str=self.pat_str,
|
||||
mergeable_ranks=mergeable_ranks,
|
||||
special_tokens=self.special_tokens,
|
||||
)
|
||||
|
||||
self.n_words: int = num_base_tokens + len(special_tokens)
|
||||
# BOS / EOS token IDs
|
||||
self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
|
||||
self.eos_id: int = self.special_tokens["<|end_of_text|>"]
|
||||
self.eot_id: int = self.special_tokens["<|eot_id|>"]
|
||||
self.eom_id: int = self.special_tokens["<|eom_id|>"]
|
||||
self.python_tag_id = self.special_tokens["<|python_tag|>"]
|
||||
self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"]
|
||||
self.stop_tokens = [
|
||||
self.eos_id,
|
||||
self.special_tokens["<|eom_id|>"],
|
||||
self.special_tokens["<|eot_id|>"],
|
||||
]
|
||||
|
||||
def encode(
|
||||
self,
|
||||
s: str,
|
||||
*,
|
||||
bos: bool,
|
||||
eos: bool,
|
||||
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = (),
|
||||
) -> List[int]:
|
||||
"""
|
||||
Encodes a string into a list of token IDs.
|
||||
|
||||
Args:
|
||||
s (str): The input string to be encoded.
|
||||
bos (bool): Whether to prepend the beginning-of-sequence token.
|
||||
eos (bool): Whether to append the end-of-sequence token.
|
||||
allowed_special ("all"|set[str]): allowed special tokens in string
|
||||
disallowed_special ("all"|set[str]): special tokens that raise an error when in string
|
||||
|
||||
Returns:
|
||||
list[int]: A list of token IDs.
|
||||
|
||||
By default, setting disallowed_special=() encodes a string by ignoring
|
||||
special tokens. Specifically:
|
||||
- Setting `disallowed_special` to () will cause all text corresponding
|
||||
to special tokens to be encoded as natural text (instead of raising
|
||||
an error).
|
||||
- Setting `allowed_special` to "all" will treat all text corresponding
to special tokens as special tokens.
|
||||
"""
|
||||
if allowed_special is None:
|
||||
allowed_special = set()
|
||||
assert type(s) is str
|
||||
|
||||
substrs = (
|
||||
substr
|
||||
for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
|
||||
for substr in self._split_whitespaces_or_nonwhitespaces(
|
||||
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
|
||||
)
|
||||
)
|
||||
t: List[int] = []
|
||||
for substr in substrs:
|
||||
t.extend(
|
||||
self.model.encode(
|
||||
substr,
|
||||
allowed_special=allowed_special,
|
||||
disallowed_special=disallowed_special,
|
||||
)
|
||||
)
|
||||
if bos:
|
||||
t.insert(0, self.bos_id)
|
||||
if eos:
|
||||
t.append(self.eos_id)
|
||||
return t
|
||||
|
||||
def decode(self, t: Sequence[int]) -> str:
|
||||
"""
|
||||
Decodes a list of token IDs into a string.
|
||||
|
||||
Args:
|
||||
t (List[int]): The list of token IDs to be decoded.
|
||||
|
||||
Returns:
|
||||
str: The decoded string.
|
||||
"""
|
||||
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
|
||||
return self.model.decode(cast(List[int], t))
|
||||
|
||||
@staticmethod
|
||||
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:
|
||||
"""
|
||||
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
|
||||
consecutive whitespaces or consecutive non-whitespaces.
|
||||
"""
|
||||
current_slice_len = 0
|
||||
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
|
||||
slice_start = 0
|
||||
|
||||
for i in range(len(s)):
|
||||
is_now_space = s[i].isspace()
|
||||
|
||||
if current_slice_is_space ^ is_now_space:
|
||||
current_slice_len = 1
|
||||
current_slice_is_space = is_now_space
|
||||
else:
|
||||
current_slice_len += 1
|
||||
if current_slice_len > max_consecutive_slice_len:
|
||||
yield s[slice_start:i]
|
||||
slice_start = i
|
||||
current_slice_len = 1
|
||||
yield s[slice_start:]
|
||||
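A short usage sketch for the tokenizer above; it assumes the bundled `tokenizer.model` file sits next to `tokenizer.py`, as added in this change:

```python
from llama_stack.models.llama.llama3.tokenizer import Tokenizer

tok = Tokenizer.get_instance()  # loads the bundled tokenizer.model
ids = tok.encode("Hello, Llama!", bos=True, eos=False)
assert ids[0] == tok.bos_id     # bos=True prepends <|begin_of_text|>
print(tok.decode(ids))          # special tokens are rendered back into the text
```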
199
llama_stack/models/llama/llama3/tool_utils.py
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||
|
||||
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
|
||||
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
|
||||
|
||||
|
||||
def is_json(s):
|
||||
try:
|
||||
parsed = json.loads(s)
|
||||
# Return True for valid objects and not for ints, strings, etc
|
||||
return isinstance(parsed, dict)
|
||||
except json.JSONDecodeError:
|
||||
return False
|
||||
|
||||
|
||||
def is_valid_python_list(input_string):
|
||||
"""Check if the input string is a valid Python list of function calls"""
|
||||
try:
|
||||
# Try to parse the string
|
||||
tree = ast.parse(input_string)
|
||||
|
||||
# Check if it's a single expression
|
||||
if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
|
||||
return False
|
||||
|
||||
# Check if the expression is a list
|
||||
expr = tree.body[0].value
|
||||
if not isinstance(expr, ast.List):
|
||||
return False
|
||||
|
||||
# Check if the list is empty
|
||||
if len(expr.elts) == 0:
|
||||
return False
|
||||
|
||||
# Check if all elements in the list are function calls
|
||||
for element in expr.elts:
|
||||
if not isinstance(element, ast.Call):
|
||||
return False
|
||||
|
||||
# Check if the function call has a valid name
|
||||
if not isinstance(element.func, ast.Name):
|
||||
return False
|
||||
|
||||
# Check if all arguments are keyword arguments
|
||||
if element.args or not all(isinstance(arg, ast.keyword) for arg in element.keywords):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except SyntaxError:
|
||||
# If parsing fails, it's not a valid Python expression
|
||||
return False
|
||||
|
||||
|
||||
def parse_python_list_for_function_calls(input_string):
|
||||
"""
|
||||
Parse a Python list of function calls and
|
||||
return a list of tuples containing the function name and arguments
|
||||
"""
|
||||
# Parse the string into an AST
|
||||
tree = ast.parse(input_string)
|
||||
|
||||
# Ensure the input is a list
|
||||
if not isinstance(tree.body[0], ast.Expr) or not isinstance(tree.body[0].value, ast.List):
|
||||
raise ValueError("Input must be a list of function calls")
|
||||
|
||||
result = []
|
||||
|
||||
# Iterate through each function call in the list
|
||||
for node in tree.body[0].value.elts:
|
||||
if isinstance(node, ast.Call):
|
||||
function_name = node.func.id
|
||||
function_args = {}
|
||||
|
||||
# Extract keyword arguments
|
||||
for keyword in node.keywords:
|
||||
function_args[keyword.arg] = ast.literal_eval(keyword.value)
|
||||
|
||||
result.append((function_name, function_args))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ToolUtils:
|
||||
@staticmethod
|
||||
def is_builtin_tool_call(message_body: str) -> bool:
|
||||
match = re.search(BUILTIN_TOOL_PATTERN, message_body)
|
||||
return match is not None
|
||||
|
||||
@staticmethod
|
||||
def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
|
||||
# Find the first match in the text
|
||||
match = re.search(BUILTIN_TOOL_PATTERN, message_body)
|
||||
|
||||
# Check if a match is found and return it
|
||||
if match:
|
||||
tool_name = match.group("tool_name")
|
||||
query = match.group("query")
|
||||
return tool_name, query
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def maybe_extract_custom_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
|
||||
# NOTE: Custom function tool calls are still experimental
|
||||
# Sometimes, response is of the form
|
||||
# {"type": "function", "name": "function_name", "parameters": {...}
|
||||
# and sometimes
|
||||
# <function=function_name>(parameters)</function>
|
||||
|
||||
# Find the first match in the text
|
||||
match = re.search(CUSTOM_TOOL_CALL_PATTERN, message_body)
|
||||
if match:
|
||||
tool_name = match.group("function_name")
|
||||
query = match.group("args")
|
||||
try:
|
||||
return tool_name, json.loads(query.replace("'", '"'))
|
||||
except Exception as e:
|
||||
print("Exception while parsing json query for custom tool call", query, e)
|
||||
return None
|
||||
elif is_json(message_body):
|
||||
response = json.loads(message_body)
|
||||
if ("type" in response and response["type"] == "function") or ("name" in response):
|
||||
function_name = response["name"]
|
||||
args = response["parameters"]
|
||||
return function_name, args
|
||||
else:
|
||||
return None
|
||||
elif is_valid_python_list(message_body):
|
||||
res = parse_python_list_for_function_calls(message_body)
|
||||
# FIXME: Enable multiple tool calls
|
||||
return res[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
|
||||
if t.tool_name == BuiltinTool.brave_search:
|
||||
q = t.arguments["query"]
|
||||
return f'brave_search.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.wolfram_alpha:
|
||||
q = t.arguments["query"]
|
||||
return f'wolfram_alpha.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.photogen:
|
||||
q = t.arguments["query"]
|
||||
return f'photogen.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.code_interpreter:
|
||||
return t.arguments["code"]
|
||||
else:
|
||||
fname = t.tool_name
|
||||
|
||||
if tool_prompt_format == ToolPromptFormat.json:
|
||||
return json.dumps(
|
||||
{
|
||||
"type": "function",
|
||||
"name": fname,
|
||||
"parameters": t.arguments,
|
||||
}
|
||||
)
|
||||
elif tool_prompt_format == ToolPromptFormat.function_tag:
|
||||
args = json.dumps(t.arguments)
|
||||
return f"<function={fname}>{args}</function>"
|
||||
|
||||
elif tool_prompt_format == ToolPromptFormat.python_list:
|
||||
|
||||
def format_value(value: RecursiveType) -> str:
|
||||
if isinstance(value, str):
|
||||
return f'"{value}"'
|
||||
elif isinstance(value, (int, float, bool)) or value is None:
|
||||
return str(value)
|
||||
elif isinstance(value, list):
|
||||
return f"[{', '.join(format_value(v) for v in value)}]"
|
||||
elif isinstance(value, dict):
|
||||
return f"{{{', '.join(f'{k}={format_value(v)}' for k, v in value.items())}}}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported type: {type(value)}")
|
||||
|
||||
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
|
||||
return f"[{fname}({args_str})]"
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")
|
||||
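A quick round-trip sketch with the helpers above; `get_weather` is a hypothetical custom tool name used only for illustration:

```python
from llama_stack.models.llama.datatypes import ToolCall, ToolPromptFormat
from llama_stack.models.llama.llama3.tool_utils import ToolUtils

call = ToolCall(call_id="1", tool_name="get_weather", arguments={"city": "Paris"})

encoded = ToolUtils.encode_tool_call(call, ToolPromptFormat.json)
print(encoded)
# {"type": "function", "name": "get_weather", "parameters": {"city": "Paris"}}

# The parser recovers the same name/arguments from a model response in that format.
print(ToolUtils.maybe_extract_custom_tool_call(encoded))
# ('get_weather', {'city': 'Paris'})
```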
358
llama_stack/models/llama/llama3_1/prompt_format.md
Normal file
|
|
@ -0,0 +1,358 @@
|
|||
|
||||
|
||||
# Llama 3.1 - Prompt Formats
|
||||
## Tokens
|
||||
Here is a list of special tokens that are supported by Llama 3.1:
|
||||
- `<|begin_of_text|>`: Specifies the start of the prompt
|
||||
- `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models.
|
||||
- `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch.
|
||||
- `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and ipython]
|
||||
- `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool.
|
||||
- `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
|
||||
- at the end of a direct interaction between the model and the user
|
||||
- at the end of multiple interactions between the model and any available tools
|
||||
This token signals to the executor that the model has finished generating a response.
|
||||
- `<|python_tag|>`: A special tag used in the model's response to signify a tool call.
|
||||
|
||||
|
||||
|
||||
There are 4 different roles that are supported by Llama 3.1
|
||||
- `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively.
|
||||
- `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
|
||||
- `ipython`: A new role introduced in Llama 3.1. Semantically, this role means "tool". This role is used to mark messages with the output of a tool call when sent back to the model from the executor.
|
||||
- `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `ipython` and `user` prompts.
|
||||
|
||||
## Llama 3.1 Base Model
|
||||
|
||||
Text completion for Llama 3.1 base model uses this format.
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|>Color of sky is blue but sometimes can also be
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
red, orange, yellow, green, purple, pink, brown, gray, black, white, and even rainbow colors. The color of the sky can change due to various reasons such as time of day, weather conditions, pollution, and atmospheric phenomena.
|
||||
The color of the sky is primarily blue because of a phenomenon called
|
||||
```
|
||||
|
||||
|
||||
|
||||
Note the special `<|begin_of_text|>` tag at the start of the prompt.
|
||||
|
||||
|
||||
## Llama 3.1 Instruct Model
|
||||
## User and assistant conversation
|
||||
|
||||
Here is a regular multi-turn user-assistant conversation and how it's formatted.
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Answer who are you in the form of jeopardy?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
Here's my response
|
||||
|
||||
"What is a helpful assistant?"<|eot_id|>
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
## Tool Calling Formats
|
||||
|
||||
|
||||
The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt:
|
||||
- Brave Search: Tool call to perform web searches.
|
||||
- Wolfram Alpha: Tool call to perform complex mathematical calculations.
|
||||
- Code Interpreter: Enables the model to output python code.
|
||||
|
||||
## Builtin Tool Calling
|
||||
|
||||
|
||||
Here is an example of a conversation using brave search
|
||||
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
Environment: ipython
|
||||
Tools: brave_search, wolfram_alpha
|
||||
Cutting Knowledge Date: December 2023
|
||||
Today Date: 21 September 2024
|
||||
|
||||
You are a helpful assistant.
|
||||
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Search the web for the latest price of 1oz gold?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
- Just including Environment: ipython turns on the code interpreter; therefore, you don't need to specify code_interpreter on the Tools: line. The model can generate Python code which is interpreted by the executor, with the result provided back to the model.
|
||||
- The message body of the assistant response starts with a special tag <|python_tag|>
|
||||
- As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|> . The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call.
|
||||
- The model tool call response is of the form `tool.call(query="...")`, where tool is `brave_search` or `wolfram_alpha` (see the parsing sketch below).
|
||||
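As referenced above, a hedged sketch of how an executor can recover the tool name and query from such a response, using the `ToolUtils` helper added elsewhere in this change:

```python
from llama_stack.models.llama.llama3.tool_utils import ToolUtils

response = '<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>'
print(ToolUtils.maybe_extract_builtin_tool_call(response))
# ('brave_search', 'latest price of 1oz gold')
```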
|
||||
|
||||
## Builtin Code Interpreter
|
||||
|
||||
Here is an actual example of the model responding with code.
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
Environment: ipython<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Write code to check if number is prime, use that to see if the number 7 is prime<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
<|python_tag|>def is_prime(n):
|
||||
if n <= 1
|
||||
return False
|
||||
for i in range(2, int(n**0.5) + 1):
|
||||
if n % i == 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
print(is_prime(7)) # Output: True<|eom_id|>
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
- The model starts with <|python_tag|> and continues writing Python code that needs to be executed.
|
||||
- No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it.
|
||||
|
||||
|
||||
## Built-in tools full interaction
|
||||
|
||||
Here is a full interaction with the built-in tools including the tool response and the final assistant response.
|
||||
|
||||
##### Input Prompt Format
|
||||
```
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
Environment: ipython
|
||||
Tools: brave_search, wolfram_alpha
|
||||
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
What is the 100th decimal of pi?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
<|python_tag|>wolfram_alpha.call(query="100th decimal of pi")<|eom_id|><|start_header_id|>ipython<|end_header_id|>
|
||||
|
||||
|
||||
{
|
||||
"queryresult": {
|
||||
"success": true,
|
||||
"inputstring": "100th decimal of pi",
|
||||
"pods": [
|
||||
{
|
||||
"title": "Input interpretation",
|
||||
"subpods": [
|
||||
{
|
||||
"title": "",
|
||||
"plaintext": "100th digit | π"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"title": "Nearby digits",
|
||||
"subpods": [
|
||||
{
|
||||
"title": "",
|
||||
"plaintext": "...86208998628034825342117067982148086513282306647093..."
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"title": "Result",
|
||||
"primary": true,
|
||||
"subpods": [
|
||||
{
|
||||
"title": "",
|
||||
"plaintext": "7"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
||||
```
|
||||
|
||||
##### Model Response Format
|
||||
```
|
||||
The 100th decimal of pi is 7.<|eot_id|>
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
- Note the `<|python_tag|>` in the assistant response.
|
||||
- Role is `ipython` for the wolfram alpha response that is passed back to the model.
|
||||
- Final message from assistant has <|eot_id|> tag.
|
||||
|
||||
|
||||
|
||||
## Zero shot tool calling

### JSON based tool calling


Llama models can now output custom tool calls from a single message to allow easier tool calling.
The following prompts provide an example of how custom tools can be called from the output of the model.
It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor.


##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Environment: ipython

Cutting Knowledge Date: December 2023
Today Date: 21 September 2024

You are a helpful assistant.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Answer the user's question by making use of the following functions if needed.
If none of the function can be used, please say so.
Here is a list of functions in JSON format:
{
    "type": "function",
    "function": {
        "name": "trending_songs",
        "description": "Returns the trending songs on a Music site",
        "parameters": {
            "type": "object",
            "properties": [
                {
                    "n": {
                        "type": "object",
                        "description": "The number of songs to return"
                    }
                },
                {
                    "genre": {
                        "type": "object",
                        "description": "The genre of the songs to return"
                    }
                }
            ],
            "required": ["n"]
        }
    }
}

Return function calls in JSON format.<|eot_id|><|start_header_id|>user<|end_header_id|>

Use tools to get latest trending songs<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
<|python_tag|>{
    "type": "function",
    "name": "trending_songs",
    "parameters": {
        "n": "10",
        "genre": "all"
    }
}<|eom_id|>
```


- JSON format for providing tools needs name, description and parameters
- Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt
- Instructions for tools added as a user message
- Only single tool calls are supported as of now

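As a sketch of the executor side (the helper name and the exact token stripping are illustrative assumptions, not a llama-stack API), the JSON payload between `<|python_tag|>` and `<|eom_id|>` can simply be parsed with `json.loads`:

```python
import json

# Hypothetical parser for the JSON-style tool call shown above.
def parse_json_tool_call(model_output: str) -> dict | None:
    """Return {"name": ..., "parameters": {...}} if the output is a JSON tool call."""
    text = model_output.strip()
    if not text.startswith("<|python_tag|>"):
        return None
    text = text.removeprefix("<|python_tag|>").removesuffix("<|eom_id|>")
    call = json.loads(text)
    return {"name": call["name"], "parameters": call.get("parameters", {})}

call = parse_json_tool_call(
    '<|python_tag|>{"type": "function", "name": "trending_songs", '
    '"parameters": {"n": "10", "genre": "all"}}<|eom_id|>'
)
print(call)  # {'name': 'trending_songs', 'parameters': {'n': '10', 'genre': 'all'}}
```
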
## Example of a user defined tool calling

### `<function>` based tool calling


Here is an example of how you could also write custom instructions for the model to do zero shot tool calling.
In this example, we define a custom tool calling format using the `<function>` tag.


##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Environment: ipython

Cutting Knowledge Date: December 2023
Today Date: 21 September 2024

You are a helpful assistant.
<|eot_id|><|start_header_id|>user<|end_header_id|>

You have access to the following functions:

Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}

Think very carefully before calling functions.
If you choose to call a function ONLY reply in the following format with no prefix or suffix:

<function=example_function_name>{"example_name": "example_value"}</function>

Reminder:
- If looking for real time information use relevant functions before falling back to brave_search
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line<|eot_id|><|start_header_id|>user<|end_header_id|>

Use tools to get latest trending songs<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
<function=trending_songs>{"n": 10}</function><|eot_id|>
```


- In this case, model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>`
- Instructions for tools added as a user message

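Because this format is user defined, parsing the response is also up to the caller. Here is a minimal sketch, assuming calls always arrive on one line as instructed above; the regex and helper name are hypothetical.

```python
import json
import re

# Hypothetical parser for the custom <function=...>{...}</function> format defined above.
FUNCTION_CALL_RE = re.compile(r"<function=(?P<name>\w+)>(?P<args>\{.*\})</function>")

def parse_function_tag_call(model_output: str) -> tuple[str, dict] | None:
    """Return (function_name, arguments) if the output follows the <function> format."""
    match = FUNCTION_CALL_RE.search(model_output)
    if match is None:
        return None
    return match.group("name"), json.loads(match.group("args"))

print(parse_function_tag_call('<function=trending_songs>{"n": 10}</function><|eot_id|>'))
# ('trending_songs', {'n': 10})
```
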
Thank You!

@@ -14,7 +14,7 @@
 import textwrap
 from typing import List
 
-from llama_models.datatypes import (
+from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     RawMessage,
     StopReason,

@@ -13,7 +13,7 @@
 import json
 import textwrap
 
-from llama_models.datatypes import (
+from llama_stack.models.llama.datatypes import (
     RawMessage,
     StopReason,
     ToolCall,

@@ -14,7 +14,7 @@
 import textwrap
 from pathlib import Path
 
-from llama_models.datatypes import (
+from llama_stack.models.llama.datatypes import (
     RawMediaItem,
     RawMessage,
     RawTextItem,

llama_stack/models/llama/llama3_2/text_prompt_format.md (new file, 286 lines)
@@ -0,0 +1,286 @@
## User and assistant conversation

Here is a regular multi-turn user assistant conversation and how it is formatted.

##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

Who are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
I'm an AI assistant, which means I'm a computer program designed to simulate conversation and answer questions to the best of my ability. I'm here to help you with any questions or tasks you may have, and I'll do my best to provide you with accurate and helpful information.

I don't have a personal name, but I'm often referred to as a "virtual assistant" or a "chatbot." I'm a machine learning model, which means I've been trained on a large dataset of text and can generate responses based on patterns and context.

I can help with a wide range of topics, from general knowledge and trivia to more specialized subjects like science, history, and technology. I can also assist with tasks like language translation, text summarization, and even generating creative content like stories or poetry.

So, what can I help you with today?<|eot_id|>
```


##### Notes
This format is unchanged from Llama3.1

## Zero shot function calling


For Llama3.2 1B and 3B instruct models, we are introducing a new format for zero shot function calling.
This new format is designed to be more flexible and powerful than the previous format.
All available functions can be provided in the system message. A key difference is in the format of how the assistant responds with function calls.
It is pythonic in the form of `[func1(params_name=params_value, params_name2=params_value2...), func2(params)]` instead of the `json` or `<function>` tag formats that were defined in Llama3.1.
Here is an example:


##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.

If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.

Here is a list of functions in JSON format that you can invoke.

[
    {
        "name": "get_weather",
        "description": "Get weather info for places",
        "parameters": {
            "type": "dict",
            "required": [
                "city"
            ],
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The name of the city to get the weather for"
                },
                "metric": {
                    "type": "string",
                    "description": "The metric for weather. Options are: celsius, fahrenheit",
                    "default": "celsius"
                }
            }
        }
    }
]<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the weather in SF and Seattle?<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
[get_weather(city='San Francisco', metric='celsius'), get_weather(city='Seattle', metric='celsius')]<|eot_id|>
```


##### Notes

- The output supports multiple tool calls natively
- JSON format for defining the functions in the system prompt is similar to Llama3.1

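A minimal sketch of how a caller might parse this pythonic format with Python's `ast` module, under the assumption that arguments are literal keyword values as in the example above; the helper name is hypothetical.

```python
import ast

# Hypothetical parser for the pythonic tool-call format used by Llama3.2 1B/3B, e.g.
# "[get_weather(city='San Francisco', metric='celsius'), get_weather(city='Seattle', metric='celsius')]"
def parse_pythonic_tool_calls(model_output: str) -> list[tuple[str, dict]]:
    """Return a list of (function_name, keyword_arguments) pairs."""
    text = model_output.strip().removesuffix("<|eot_id|>").removesuffix("<|eom_id|>")
    tree = ast.parse(text, mode="eval")
    assert isinstance(tree.body, ast.List), "expected a list of calls"

    calls = []
    for node in tree.body.elts:
        assert isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
        kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
        calls.append((node.func.id, kwargs))
    return calls

print(parse_pythonic_tool_calls(
    "[get_weather(city='San Francisco', metric='celsius'), "
    "get_weather(city='Seattle', metric='celsius')]<|eot_id|>"
))
# [('get_weather', {'city': 'San Francisco', 'metric': 'celsius'}),
#  ('get_weather', {'city': 'Seattle', 'metric': 'celsius'})]
```
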
## Zero shot function calling with user message


While the default is to provide all function definitions in a system message, in Llama3.2 text models you can also provide information for all the available tools in a user message.


##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Questions: Can you retrieve the details for the user with the ID 7890, who has black as their special request?
Here is a list of functions in JSON format that you can invoke:
[
    {
        "name": "get_user_info",
        "description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",
        "parameters": {
            "type": "dict",
            "required": [
                "user_id"
            ],
            "properties": {
                "user_id": {
                    "type": "integer",
                    "description": "The unique identifier of the user. It is used to fetch the specific user details from the database."
                },
                "special": {
                    "type": "string",
                    "description": "Any special information or parameters that need to be considered while fetching user details.",
                    "default": "none"
                }
            }
        }
    }
]

Should you decide to return the function call(s),Put it in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)]

NO other text MUST be included.<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
[get_user_info(user_id=7890, special='black')]<|eot_id|>
```


##### Notes

- The tool call format for the model is the same whether your function definitions are provided in the system or user message.
- While builtin tool calls end with a <|eom_id|>, notice the <|eot_id|> for zero shot tool calls.

## Code Interpreter


Code Interpreter continues to work in 3.2 text models similarly to the Llama 3.1 model family.
Here is an example:


##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Environment: ipython

Cutting Knowledge Date: December 2023
Today Date: 24 September 2024
<|eot_id|><|start_header_id|>user<|end_header_id|>

Write code to check if number is prime. Use it to verify if number 7 is prime<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
<|python_tag|>def is_prime(n):
    if n <= 1:
        return False
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    max_divisor = int(n**0.5) + 1
    for d in range(3, max_divisor, 2):
        if n % d == 0:
            return False
    return True

print(is_prime(7))  # Output: True<|eom_id|>
```


##### Notes

- Note `Environment: ipython` in the system prompt.
- Note that the response starts with `<|python_tag|>` and ends with `<|eom_id|>`

## Zero shot function calling E2E format


Here is an example of the e2e cycle of tool calls with the model in a multi-step way.


##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.

If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.

Here is a list of functions in JSON format that you can invoke.

[
    {
        "name": "get_weather",
        "description": "Get weather info for places",
        "parameters": {
            "type": "dict",
            "required": [
                "city"
            ],
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The name of the city to get the weather for"
                },
                "metric": {
                    "type": "string",
                    "description": "The metric for weather. Options are: celsius, fahrenheit",
                    "default": "celsius"
                }
            }
        }
    }
]<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the weather in SF?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

<|python_tag|>[get_weather(city="San Francisco", metric="celsius")]<|eot_id|><|start_header_id|>ipython<|end_header_id|>

"25 C"<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
The weather in San Francisco is 25 C.<|eot_id|>
```


##### Notes

- The output of the function call is provided back to the model as a tool response (in JSON format).
- Notice `<|start_header_id|>ipython<|end_header_id|>` as the header message preceding the tool response.
- The model finally summarizes the information from the tool response and returns the result to the user.

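Here is a rough sketch of driving this multi-step loop from the client side. The `generate` callable, the prompt assembly, and the stub `get_weather` tool are all assumptions for illustration; they are not llama-stack APIs.

```python
import ast

def parse_calls(output: str) -> list[tuple[str, dict]]:
    """Minimal parse of '<|python_tag|>[func(a=1)]<|eot_id|>' into (name, kwargs) pairs."""
    body = (
        output.strip()
        .removeprefix("<|python_tag|>")
        .removesuffix("<|eot_id|>")
        .removesuffix("<|eom_id|>")
    )
    return [
        (node.func.id, {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords})
        for node in ast.parse(body, mode="eval").body.elts
    ]

def get_weather(city: str, metric: str = "celsius") -> str:
    return "25 C"  # stub tool; a real implementation would query a weather API

def run_weather_turn(generate, system_prompt: str, question: str) -> str:
    """Drive one multi-step turn: model -> tool call -> ipython response -> final answer."""
    prompt = (
        f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}"
        f"<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{question}"
        f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    first = generate(prompt)  # e.g. '<|python_tag|>[get_weather(city="San Francisco", metric="celsius")]<|eot_id|>'
    name, kwargs = parse_calls(first)[0]
    result = get_weather(**kwargs) if name == "get_weather" else ""

    # Feed the tool output back under the ipython role and ask for the final answer.
    prompt += (
        f"{first}<|start_header_id|>ipython<|end_header_id|>\n\n"
        f'"{result}"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
    )
    return generate(prompt)  # e.g. "The weather in San Francisco is 25 C.<|eot_id|>"
```
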
## Prompt format for base models


For base models (Llama3.2-1B and Llama3.2-3B), the prompt format for a simple completion is as follows


##### Input Prompt Format
```
<|begin_of_text|>The color of the sky is blue but sometimes it can also be
```

##### Model Response Format
```
gray or even purple. The color of the sky can change depending on the time of day, the weather, and the amount of pollution in the air. The color of the sky can also be affected by the presence of dust, smoke, and other particles in the air.

## Step 1: Identify the factors that
```


##### Notes
Same as Llama3.1

Thank You!

llama_stack/models/llama/llama3_2/vision_prompt_format.md (new file, 141 lines)
@@ -0,0 +1,141 @@
## User and assistant conversation

Here is a regular multi-turn user assistant conversation and how it is formatted.

##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

Who are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
I'm a helpful assistant, here to provide information, answer questions, and assist with tasks to the best of my abilities. I'm a large language model, which means I can understand and respond to natural language inputs, and I'm constantly learning and improving to provide more accurate and helpful responses.

I can help with a wide range of topics, from general knowledge and trivia to more specific areas like science, history, technology, and more. I can also assist with tasks like language translation, text summarization, and even generating creative content like stories or dialogues.

What can I help you with today?<|eot_id|>
```


##### Notes
This format is unchanged from Llama3.1

## User and assistant conversation with Images

This example shows how to pass an image to the model as part of the messages.

##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>user<|end_header_id|>

<|image|>Describe this image in two sentences<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
The image depicts a small dog standing on a skateboard, with its front paws firmly planted on the board and its back paws slightly raised. The dog's fur is predominantly brown and white, with a distinctive black stripe running down its back, and it is wearing a black collar around its neck.<|eot_id|>
```


##### Notes

- The `<|image|>` tag is used to indicate the presence of the image
- The model isn't an early fusion model, so it doesn't actually translate an image into several tokens. Instead, the cross-attention layers take input "on the side" from a vision encoder

- It's important to position the <|image|> tag appropriately in the prompt. The image will only attend to the subsequent text tokens
- The <|image|> tag is part of the user message body, implying that it should only come after the header `<|start_header_id|>{role}<|end_header_id|>` in the message body
- We recommend using a single image in one prompt

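A tiny, hypothetical helper illustrating the placement rule from the notes above: the `<|image|>` tag goes inside the user message body, after the role header and before the text that should attend to the image. The image bytes themselves are fed to the vision encoder separately; the tag only marks the position in the text stream.

```python
# Hypothetical prompt builder for a single-image vision turn; not a llama-stack API.
def build_vision_prompt(user_text: str) -> str:
    return (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        f"<|image|>{user_text}"   # image tag first, then the text that should attend to it
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )

print(build_vision_prompt("Describe this image in two sentences"))
```
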
## Builtin and Zero Shot Tool Calling


Llama3.2 vision models follow the same tool calling format as Llama3.1 models when inputs are text only.
Use `Environment: ipython` to enable tools.
Add `Tools: {{tool_name1}},{{tool_name2}}` for each of the builtin tools.
The same builtin tools as Llama3.1 are available:
- code_interpreter (for executing python code)
- brave_search (to search the web)
- wolfram_alpha (for querying wolfram alpha for mathematical questions)


##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Environment: ipython
Tools: brave_search, wolfram_alpha
Cutting Knowledge Date: December 2023
Today Date: 23 September 2024

You are a helpful assistant.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Search the web for the latest price of 1oz gold?<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>
```


##### Notes

- Note the `<|python_tag|>` before `brave_search` function call.
- The `<|eom_id|>` tag is used to indicate the end of the message.
- Similar to Llama3.1, code_interpreter is not explicitly mentioned but is enabled via `Environment: ipython`.
- Tool Calling does NOT work with images in the prompt as of now.

## Prompt format for base models


For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), the prompt format for a simple completion is as follows


##### Input Prompt Format
```
<|begin_of_text|>The color of the sky is blue but sometimes it can also be
```

##### Model Response Format
```
red, orange, pink, purple, and even black. The color of the sky is determined by the amount of sunlight that is scattered by the atmosphere and the amount of dust and water vapor present in the atmosphere. During sunrise and sunset, the sky can take on a range of colors due to the scattering of light by
```


##### Notes
- Same as Llama3.1

## Prompt format for base models with Image


For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), here is an example of how the text completion format looks with an image,


##### Input Prompt Format
```
<|begin_of_text|><|image|>If I had to write a haiku for this one
```

##### Model Response Format
```
, it would be: A skateboarder's delight, a puppy on a board, a furry little thrill-seeker. This puppy is a true skateboarding enthusiast, always eager to hit the streets and show off his skills. He's a master of the board, gliding effortlessly across the pavement with grace and style.
```


##### Notes
- Note the placement of the special tags <|begin_of_text|> and <|image|>

Thank You!

@ -14,7 +14,7 @@
|
|||
import textwrap
|
||||
from typing import List
|
||||
|
||||
from llama_models.datatypes import (
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
RawMessage,
|
||||
StopReason,
|
||||
|
|
|
|||
|
|
@ -16,7 +16,9 @@ import textwrap
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from llama_models.datatypes import (
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
RawContent,
|
||||
RawMediaItem,
|
||||
RawMessage,
|
||||
|
|
@ -25,7 +27,6 @@ from llama_models.datatypes import (
|
|||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .llama3.interface import LLama31Interface
|
||||
from .llama3.template_data import (
|
||||
|
|
|
|||
|
|
@ -6,18 +6,18 @@
|
|||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentToolGroup,
|
||||
|
|
@ -31,7 +31,6 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnResponseStreamChunk,
|
||||
AgentTurnResponseTurnAwaitingInputPayload,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
AgentTurnResponseTurnStartPayload,
|
||||
AgentTurnResumeRequest,
|
||||
Attachment,
|
||||
Document,
|
||||
|
|
@ -79,8 +78,6 @@ from llama_stack.providers.utils.telemetry import tracing
|
|||
from .persistence import AgentPersistence
|
||||
from .safety import SafetyException, ShieldRunnerMixin
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_random_string(length: int = 8):
|
||||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
||||
|
|
@ -186,115 +183,61 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
messages = await self.get_messages_from_turns(turns)
|
||||
messages.extend(request.messages)
|
||||
|
||||
turn_id = str(uuid.uuid4())
|
||||
span.set_attribute("turn_id", turn_id)
|
||||
start_time = datetime.now().astimezone().isoformat()
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnStartPayload(
|
||||
turn_id=turn_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
steps = []
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=turn_id,
|
||||
input_messages=messages,
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
documents=request.documents,
|
||||
toolgroups_for_turn=request.toolgroups,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
log.info(
|
||||
f"{chunk.role.capitalize()}: {chunk.content}",
|
||||
)
|
||||
output_message = chunk
|
||||
continue
|
||||
|
||||
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
||||
steps.append(event.payload.step_details)
|
||||
|
||||
async for chunk in self._run_turn(request, turn_id):
|
||||
yield chunk
|
||||
|
||||
assert output_message is not None
|
||||
|
||||
turn = Turn(
|
||||
turn_id=turn_id,
|
||||
session_id=request.session_id,
|
||||
input_messages=request.messages,
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
||||
if output_message.tool_calls and request.allow_turn_resume:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield chunk
|
||||
|
||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||
with tracing.span("resume_turn") as span:
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("turn_id", request.turn_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
async for chunk in self._run_turn(request):
|
||||
yield chunk
|
||||
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
async def _run_turn(
|
||||
self,
|
||||
request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest],
|
||||
turn_id: Optional[str] = None,
|
||||
) -> AsyncGenerator:
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
if len(turns) == 0:
|
||||
raise ValueError("No turns found for session")
|
||||
is_resume = isinstance(request, AgentTurnResumeRequest)
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
|
||||
messages = await self.get_messages_from_turns(turns)
|
||||
messages.extend(request.tool_responses)
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
if is_resume and len(turns) == 0:
|
||||
raise ValueError("No turns found for session")
|
||||
|
||||
steps = []
|
||||
messages = await self.get_messages_from_turns(turns)
|
||||
if is_resume:
|
||||
if isinstance(request.tool_responses[0], ToolResponseMessage):
|
||||
tool_response_messages = request.tool_responses
|
||||
tool_responses = [
|
||||
ToolResponse(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
|
||||
for x in request.tool_responses
|
||||
]
|
||||
else:
|
||||
tool_response_messages = [
|
||||
ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
|
||||
for x in request.tool_responses
|
||||
]
|
||||
tool_responses = request.tool_responses
|
||||
messages.extend(tool_response_messages)
|
||||
last_turn = turns[-1]
|
||||
last_turn_messages = self.turn_to_messages(last_turn)
|
||||
last_turn_messages = [
|
||||
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
||||
]
|
||||
last_turn_messages.extend(tool_response_messages)
|
||||
|
||||
# TODO: figure out whether we should add the tool responses to the last turn messages
|
||||
last_turn_messages.extend(request.tool_responses)
|
||||
|
||||
# get the steps from the turn id
|
||||
steps = []
|
||||
steps = turns[-1].steps
|
||||
# get steps from the turn
|
||||
steps = last_turn.steps
|
||||
|
||||
# mark tool execution step as complete
|
||||
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
|
||||
|
|
@ -307,14 +250,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
||||
turn_id=request.turn_id,
|
||||
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
|
||||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id=x.call_id,
|
||||
tool_name=x.tool_name,
|
||||
content=x.content,
|
||||
)
|
||||
for x in request.tool_responses
|
||||
],
|
||||
tool_responses=tool_responses,
|
||||
completed_at=now,
|
||||
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
|
||||
)
|
||||
|
|
@ -328,62 +264,67 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
)
|
||||
input_messages = last_turn_messages
|
||||
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=request.turn_id,
|
||||
input_messages=messages,
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
output_message = chunk
|
||||
continue
|
||||
turn_id = request.turn_id
|
||||
start_time = last_turn.started_at
|
||||
else:
|
||||
messages.extend(request.messages)
|
||||
start_time = datetime.now().astimezone().isoformat()
|
||||
input_messages = request.messages
|
||||
|
||||
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
||||
steps.append(event.payload.step_details)
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=turn_id,
|
||||
input_messages=messages,
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
documents=request.documents if not is_resume else None,
|
||||
toolgroups_for_turn=request.toolgroups if not is_resume else None,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
output_message = chunk
|
||||
continue
|
||||
|
||||
yield chunk
|
||||
|
||||
assert output_message is not None
|
||||
|
||||
last_turn_start_time = datetime.now().astimezone().isoformat()
|
||||
if len(turns) > 0:
|
||||
last_turn_start_time = turns[-1].started_at
|
||||
|
||||
turn = Turn(
|
||||
turn_id=request.turn_id,
|
||||
session_id=request.session_id,
|
||||
input_messages=last_turn_messages,
|
||||
output_message=output_message,
|
||||
started_at=last_turn_start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
||||
if output_message.tool_calls:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
||||
steps.append(event.payload.step_details)
|
||||
|
||||
yield chunk
|
||||
|
||||
assert output_message is not None
|
||||
|
||||
turn = Turn(
|
||||
turn_id=turn_id,
|
||||
session_id=request.session_id,
|
||||
input_messages=input_messages,
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
if output_message.tool_calls:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield chunk
|
||||
|
||||
async def run(
|
||||
self,
|
||||
session_id: str,
|
||||
|
|
@ -533,9 +474,18 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if documents:
|
||||
await self.handle_documents(session_id, documents, input_messages, tool_defs)
|
||||
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
# if the session has a memory bank id, let the memory tool use it
|
||||
if session_info and session_info.vector_db_id:
|
||||
if RAG_TOOL_GROUP not in toolgroup_args:
|
||||
toolgroup_args[RAG_TOOL_GROUP] = {"vector_db_ids": [session_info.vector_db_id]}
|
||||
else:
|
||||
toolgroup_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id)
|
||||
|
||||
output_attachments = []
|
||||
|
||||
n_iter = 0
|
||||
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
|
||||
|
||||
# Build a map of custom tools to their definitions for faster lookup
|
||||
client_tools = {}
|
||||
for tool in self.agent_config.client_tools:
|
||||
|
|
@ -622,6 +572,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
span.set_attribute("output", output_attr)
|
||||
|
||||
n_iter += 1
|
||||
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
|
||||
|
||||
stop_reason = stop_reason or StopReason.out_of_tokens
|
||||
|
||||
# If tool calls are parsed successfully,
|
||||
|
|
@ -656,12 +609,15 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
|
||||
if n_iter >= self.agent_config.max_infer_iters:
|
||||
log.info("Done with MAX iterations, exiting.")
|
||||
logcat.info("agents", f"done with MAX iterations ({n_iter}), exiting.")
|
||||
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
|
||||
# Do not continue the tool call loop after this point
|
||||
message.stop_reason = StopReason.end_of_turn
|
||||
yield message
|
||||
break
|
||||
|
||||
if stop_reason == StopReason.out_of_tokens:
|
||||
log.info("Out of token budget, exiting.")
|
||||
logcat.info("agents", "out of token budget, exiting.")
|
||||
yield message
|
||||
break
|
||||
|
||||
|
|
@ -675,10 +631,16 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
message.content = [message.content] + output_attachments
|
||||
yield message
|
||||
else:
|
||||
log.info(f"Partial message: {str(message)}")
|
||||
logcat.debug(
|
||||
"agents",
|
||||
f"completion message with EOM (iter: {n_iter}): {str(message)}",
|
||||
)
|
||||
input_messages = input_messages + [message]
|
||||
else:
|
||||
log.info(f"{str(message)}")
|
||||
logcat.debug(
|
||||
"agents",
|
||||
f"completion message (iter: {n_iter}) from the model: {str(message)}",
|
||||
)
|
||||
# 1. Start the tool execution step and progress
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
|
@ -706,6 +668,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
# If tool is a client tool, yield CompletionMessage and return
|
||||
if tool_call.tool_name in client_tools:
|
||||
# NOTE: mark end_of_message to indicate to client that it may
|
||||
# call the tool and continue the conversation with the tool's response.
|
||||
message.stop_reason = StopReason.end_of_message
|
||||
await self.storage.set_in_progress_tool_call_step(
|
||||
session_id,
|
||||
turn_id,
|
||||
|
|
@ -791,24 +756,16 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
input_messages = input_messages + [message, result_message]
|
||||
|
||||
n_iter += 1
|
||||
|
||||
async def _get_tool_defs(
|
||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
||||
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
|
||||
# Determine which tools to include
|
||||
agent_config_toolgroups = set(
|
||||
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
|
||||
for toolgroup in self.agent_config.toolgroups
|
||||
)
|
||||
toolgroups_for_turn_set = (
|
||||
agent_config_toolgroups
|
||||
if toolgroups_for_turn is None
|
||||
else {
|
||||
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
|
||||
for toolgroup in toolgroups_for_turn
|
||||
}
|
||||
)
|
||||
tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
|
||||
agent_config_toolgroups = []
|
||||
for toolgroup in tool_groups_to_include:
|
||||
name = toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
|
||||
if name not in agent_config_toolgroups:
|
||||
agent_config_toolgroups.append(name)
|
||||
|
||||
tool_name_to_def = {}
|
||||
tool_to_group = {}
|
||||
|
|
@ -831,9 +788,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
tool_to_group[tool_def.name] = "__client_tools__"
|
||||
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||
if toolgroup_name_with_maybe_tool_name not in toolgroups_for_turn_set:
|
||||
continue
|
||||
|
||||
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
||||
if not tools.data:
|
||||
|
|
@ -1029,7 +983,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
|||
path = urlparse(uri).path
|
||||
basename = os.path.basename(path)
|
||||
filepath = f"{tempdir}/{make_random_string() + basename}"
|
||||
log.info(f"Downloading {url} -> {filepath}")
|
||||
logcat.info("agents", f"Downloading {url} -> {filepath}")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(uri)
|
||||
|
|
@ -1069,6 +1023,7 @@ async def execute_tool_call_maybe(
|
|||
else:
|
||||
name = name.value
|
||||
|
||||
logcat.info("agents", f"executing tool call: {name} with args: {tool_call.arguments}")
|
||||
result = await tool_runtime_api.invoke_tool(
|
||||
tool_name=name,
|
||||
kwargs={
|
||||
|
|
@ -1078,6 +1033,7 @@ async def execute_tool_call_maybe(
|
|||
**toolgroup_args.get(group_name, {}),
|
||||
},
|
||||
)
|
||||
logcat.debug("agents", f"tool call {name} completed with result: {result}")
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from llama_stack.apis.agents import (
|
|||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
ToolConfig,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
|
@ -140,7 +141,6 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
documents: Optional[List[Document]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
allow_turn_resume: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=agent_id,
|
||||
|
|
@ -150,7 +150,6 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
toolgroups=toolgroups,
|
||||
documents=documents,
|
||||
tool_config=tool_config,
|
||||
allow_turn_resume=allow_turn_resume,
|
||||
)
|
||||
if stream:
|
||||
return self._create_agent_turn_streaming(request)
|
||||
|
|
@ -170,7 +169,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: List[ToolResponseMessage],
|
||||
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnResumeRequest(
|
||||
|
|
|
|||
|
|
@ -105,3 +105,15 @@ class AgentPersistence:
|
|||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return ToolExecutionStep(**json.loads(value)) if value else None
|
||||
|
||||
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
|
||||
await self.kvstore.set(
|
||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
value=str(num_infer_iters),
|
||||
)
|
||||
|
||||
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
|
||||
value = await self.kvstore.get(
|
||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return int(value) if value else None
|
||||
|
|
|
|||
|
|
@ -1,400 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import tempfile
|
||||
from typing import AsyncIterator, List, Optional, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentToolGroupWithArgs,
|
||||
AgentTurnCreateRequest,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
StepType,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import RunShieldResponse
|
||||
from llama_stack.apis.tools import (
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolHost,
|
||||
ToolInvocationResult,
|
||||
)
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
||||
MEMORY_QUERY_TOOL,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.agents import (
|
||||
MetaReferenceAgentsImpl,
|
||||
MetaReferenceAgentsImplConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class MockInferenceAPI:
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = None,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
async def stream_response():
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type="start",
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type="progress",
|
||||
delta="AI is a fascinating field...",
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type="complete",
|
||||
delta="",
|
||||
stop_reason="end_of_turn",
|
||||
)
|
||||
)
|
||||
|
||||
if stream:
|
||||
return stream_response()
|
||||
else:
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
role="assistant",
|
||||
content="Mock response",
|
||||
stop_reason="end_of_turn",
|
||||
),
|
||||
logprobs={"token_logprobs": [0.1, 0.2, 0.3]} if logprobs else None,
|
||||
)
|
||||
|
||||
|
||||
class MockSafetyAPI:
|
||||
async def run_shield(self, shield_id: str, messages: List[Message]) -> RunShieldResponse:
|
||||
return RunShieldResponse(violation=None)
|
||||
|
||||
|
||||
class MockVectorIOAPI:
|
||||
def __init__(self):
|
||||
self.chunks = {}
|
||||
|
||||
async def insert_chunks(self, vector_db_id, chunks, ttl_seconds=None):
|
||||
for chunk in chunks:
|
||||
metadata = chunk.metadata
|
||||
self.chunks[vector_db_id][metadata["document_id"]] = chunk
|
||||
|
||||
async def query_chunks(self, vector_db_id, query, params=None):
|
||||
if vector_db_id not in self.chunks:
|
||||
raise ValueError(f"Bank {vector_db_id} not found")
|
||||
|
||||
chunks = list(self.chunks[vector_db_id].values())
|
||||
scores = [1.0] * len(chunks)
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class MockToolGroupsAPI:
|
||||
async def register_tool_group(self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None) -> None:
|
||||
pass
|
||||
|
||||
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
|
||||
return ToolGroup(
|
||||
identifier=toolgroup_id,
|
||||
provider_resource_id=toolgroup_id,
|
||||
)
|
||||
|
||||
async def list_tool_groups(self) -> List[ToolGroup]:
|
||||
return []
|
||||
|
||||
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
|
||||
if tool_group_id == MEMORY_TOOLGROUP:
|
||||
return [
|
||||
Tool(
|
||||
identifier=MEMORY_QUERY_TOOL,
|
||||
provider_resource_id=MEMORY_QUERY_TOOL,
|
||||
toolgroup_id=MEMORY_TOOLGROUP,
|
||||
tool_host=ToolHost.client,
|
||||
description="Mock tool",
|
||||
provider_id="builtin::rag",
|
||||
parameters=[],
|
||||
)
|
||||
]
|
||||
if tool_group_id == CODE_INTERPRETER_TOOLGROUP:
|
||||
return [
|
||||
Tool(
|
||||
identifier="code_interpreter",
|
||||
provider_resource_id="code_interpreter",
|
||||
toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
|
||||
tool_host=ToolHost.client,
|
||||
description="Mock tool",
|
||||
provider_id="builtin::code_interpreter",
|
||||
parameters=[],
|
||||
)
|
||||
]
|
||||
return []
|
||||
|
||||
async def get_tool(self, tool_name: str) -> Tool:
|
||||
return Tool(
|
||||
identifier=tool_name,
|
||||
provider_resource_id=tool_name,
|
||||
toolgroup_id="mock_group",
|
||||
tool_host=ToolHost.client,
|
||||
description="Mock tool",
|
||||
provider_id="mock_provider",
|
||||
parameters=[],
|
||||
)
|
||||
|
||||
async def unregister_tool_group(self, tool_group_id: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class MockToolRuntimeAPI:
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return []
|
||||
|
||||
async def invoke_tool(self, tool_name: str, args: dict) -> ToolInvocationResult:
|
||||
return ToolInvocationResult(content={"result": "Mock tool result"})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_inference_api():
|
||||
return MockInferenceAPI()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_safety_api():
|
||||
return MockSafetyAPI()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_io_api():
|
||||
return MockVectorIOAPI()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_groups_api():
|
||||
return MockToolGroupsAPI()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_runtime_api():
|
||||
return MockToolRuntimeAPI()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def get_agents_impl(
|
||||
mock_inference_api,
|
||||
mock_safety_api,
|
||||
mock_vector_io_api,
|
||||
mock_tool_runtime_api,
|
||||
mock_tool_groups_api,
|
||||
):
|
||||
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||
impl = MetaReferenceAgentsImpl(
|
||||
config=MetaReferenceAgentsImplConfig(
|
||||
persistence_store=SqliteKVStoreConfig(
|
||||
db_name=sqlite_file.name,
|
||||
),
|
||||
),
|
||||
inference_api=mock_inference_api,
|
||||
safety_api=mock_safety_api,
|
||||
vector_io_api=mock_vector_io_api,
|
||||
tool_runtime_api=mock_tool_runtime_api,
|
||||
tool_groups_api=mock_tool_groups_api,
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def get_chat_agent(get_agents_impl):
|
||||
impl = await get_agents_impl
|
||||
agent_config = AgentConfig(
|
||||
model="test_model",
|
||||
instructions="You are a helpful assistant.",
|
||||
toolgroups=[],
|
||||
tool_choice=ToolChoice.auto,
|
||||
enable_session_persistence=False,
|
||||
input_shields=["test_shield"],
|
||||
)
|
||||
response = await impl.create_agent(agent_config)
|
||||
return await impl.get_agent(response.agent_id)
|
||||
|
||||
|
||||
MEMORY_TOOLGROUP = "builtin::rag"
|
||||
CODE_INTERPRETER_TOOLGROUP = "builtin::code_interpreter"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def get_chat_agent_with_tools(get_agents_impl, request):
|
||||
impl = await get_agents_impl
|
||||
toolgroups = request.param
|
||||
agent_config = AgentConfig(
|
||||
model="test_model",
|
||||
instructions="You are a helpful assistant.",
|
||||
toolgroups=toolgroups,
|
||||
tool_choice=ToolChoice.auto,
|
||||
enable_session_persistence=False,
|
||||
input_shields=["test_shield"],
|
||||
)
|
||||
response = await impl.create_agent(agent_config)
|
||||
return await impl.get_agent(response.agent_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_agent_create_and_execute_turn(get_chat_agent):
|
||||
chat_agent = await get_chat_agent
|
||||
session_id = await chat_agent.create_session("Test Session")
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=chat_agent.agent_id,
|
||||
session_id=session_id,
|
||||
messages=[UserMessage(content="Hello")],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
responses = []
|
||||
async for response in chat_agent.create_and_execute_turn(request):
|
||||
responses.append(response)
|
||||
|
||||
assert len(responses) > 0
|
||||
assert (
|
||||
len(responses) == 7
|
||||
) # TurnStart, ShieldCallStart, ShieldCallComplete, StepStart, StepProgress, StepComplete, TurnComplete
|
||||
assert responses[0].event.payload.turn_id is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_multiple_shields_wrapper(get_chat_agent):
|
||||
chat_agent = await get_chat_agent
|
||||
messages = [UserMessage(content="Test message")]
|
||||
shields = ["test_shield"]
|
||||
|
||||
responses = [
|
||||
chunk
|
||||
async for chunk in chat_agent.run_multiple_shields_wrapper(
|
||||
turn_id="test_turn_id",
|
||||
messages=messages,
|
||||
shields=shields,
|
||||
touchpoint="user-input",
|
||||
)
|
||||
]
|
||||
|
||||
assert len(responses) == 2 # StepStart, StepComplete
|
||||
assert responses[0].event.payload.step_type.value == "shield_call"
|
||||
assert not responses[1].event.payload.step_details.violation
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_agent_complex_turn(get_chat_agent):
|
||||
chat_agent = await get_chat_agent
|
||||
session_id = await chat_agent.create_session("Test Session")
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=chat_agent.agent_id,
|
||||
session_id=session_id,
|
||||
messages=[UserMessage(content="Tell me about AI and then use a tool.")],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
responses = []
|
||||
async for response in chat_agent.create_and_execute_turn(request):
|
||||
responses.append(response)
|
||||
|
||||
assert len(responses) > 0
|
||||
|
||||
step_types = [
|
||||
response.event.payload.step_type for response in responses if hasattr(response.event.payload, "step_type")
|
||||
]
|
||||
|
||||
assert StepType.shield_call in step_types, "Shield call step is missing"
|
||||
assert StepType.inference in step_types, "Inference step is missing"
|
||||
|
||||
event_types = [
|
||||
response.event.payload.event_type for response in responses if hasattr(response.event.payload, "event_type")
|
||||
]
|
||||
assert "turn_start" in event_types, "Start event is missing"
|
||||
assert "turn_complete" in event_types, "Complete event is missing"
|
||||
|
||||
assert any(isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload) for response in responses), (
|
||||
"Turn complete event is missing"
|
||||
)
|
||||
turn_complete_payload = next(
|
||||
response.event.payload
|
||||
for response in responses
|
||||
if isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
|
||||
)
|
||||
turn = turn_complete_payload.turn
|
||||
assert turn.input_messages == request.messages, "Input messages do not match"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"toolgroups, expected_memory, expected_code_interpreter",
|
||||
[
|
||||
([], False, False), # no tools
|
||||
([MEMORY_TOOLGROUP], True, False), # memory only
|
||||
([CODE_INTERPRETER_TOOLGROUP], False, True), # code interpreter only
|
||||
([MEMORY_TOOLGROUP, CODE_INTERPRETER_TOOLGROUP], True, True), # all tools
|
||||
],
|
||||
)
|
||||
async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, expected_code_interpreter):
|
||||
impl = await get_agents_impl
|
||||
agent_config = AgentConfig(
|
||||
model="test_model",
|
||||
instructions="You are a helpful assistant.",
|
||||
toolgroups=toolgroups,
|
||||
tool_choice=ToolChoice.auto,
|
||||
enable_session_persistence=False,
|
||||
input_shields=["test_shield"],
|
||||
)
|
||||
response = await impl.create_agent(agent_config)
|
||||
chat_agent = await impl.get_agent(response.agent_id)
|
||||
|
||||
tool_defs, _ = await chat_agent._get_tool_defs()
|
||||
if expected_memory:
|
||||
assert MEMORY_QUERY_TOOL in tool_defs
|
||||
if expected_code_interpreter:
|
||||
assert BuiltinTool.code_interpreter in tool_defs
|
||||
if expected_memory and expected_code_interpreter:
|
||||
# override the tools for turn
|
||||
new_tool_defs, _ = await chat_agent._get_tool_defs(
|
||||
toolgroups_for_turn=[
|
||||
AgentToolGroupWithArgs(
|
||||
name=MEMORY_TOOLGROUP,
|
||||
args={"vector_dbs": ["test_vector_db"]},
|
||||
)
|
||||
]
|
||||
)
|
||||
assert MEMORY_QUERY_TOOL in new_tool_defs
|
||||
assert BuiltinTool.code_interpreter not in new_tool_defs
|
||||
|
|
@ -3,6 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from tqdm import tqdm
|
||||
|
|
@ -82,23 +83,22 @@ class MetaReferenceEvalImpl(
|
|||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
task_config: BenchmarkConfig,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
task_def = self.benchmarks[benchmark_id]
|
||||
dataset_id = task_def.dataset_id
|
||||
candidate = task_config.eval_candidate
|
||||
scoring_functions = task_def.scoring_functions
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=(-1 if task_config.num_examples is None else task_config.num_examples),
|
||||
rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
|
||||
)
|
||||
res = await self.evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=all_rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
task_config=task_config,
|
||||
benchmark_config=benchmark_config,
|
||||
)
|
||||
|
||||
# TODO: currently needs to wait for generation before returning
|
||||
|
|
@ -108,16 +108,16 @@ class MetaReferenceEvalImpl(
|
|||
return Job(job_id=job_id)
|
||||
|
||||
async def _run_agent_generation(
|
||||
self, input_rows: List[Dict[str, Any]], task_config: BenchmarkConfig
|
||||
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
) -> List[Dict[str, Any]]:
|
||||
candidate = task_config.eval_candidate
|
||||
candidate = benchmark_config.eval_candidate
|
||||
create_response = await self.agents_api.create_agent(candidate.config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
generations = []
|
||||
for i, x in tqdm(enumerate(input_rows)):
|
||||
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
|
||||
input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
|
||||
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in input_messages]
|
||||
|
||||
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
|
||||
|
|
@ -151,15 +151,15 @@ class MetaReferenceEvalImpl(
|
|||
return generations
|
||||
|
||||
async def _run_model_generation(
|
||||
self, input_rows: List[Dict[str, Any]], task_config: BenchmarkConfig
|
||||
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
) -> List[Dict[str, Any]]:
|
||||
candidate = task_config.eval_candidate
|
||||
candidate = benchmark_config.eval_candidate
|
||||
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
|
||||
|
||||
generations = []
|
||||
for x in tqdm(input_rows):
|
||||
if ColumnName.completion_input.value in x:
|
||||
input_content = eval(str(x[ColumnName.completion_input.value]))
|
||||
input_content = json.loads(x[ColumnName.completion_input.value])
|
||||
response = await self.inference_api.completion(
|
||||
model=candidate.model,
|
||||
content=input_content,
|
||||
|
|
@ -167,9 +167,8 @@ class MetaReferenceEvalImpl(
|
|||
)
|
||||
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
|
||||
elif ColumnName.chat_completion_input.value in x:
|
||||
chat_completion_input_str = str(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = eval(chat_completion_input_str)
|
||||
input_messages = [UserMessage(**x) for x in input_messages]
|
||||
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in chat_completion_input_json]
|
||||
messages = []
|
||||
if candidate.system_message:
|
||||
messages.append(candidate.system_message)
|
||||
|
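The hunks above replace `eval()` with `json.loads()` when decoding stored rows; a minimal sketch of the safer parsing path, assuming rows keep their chat input as a JSON-encoded list of message dicts (the literal row content is illustrative):

import json

row = {"chat_completion_input": '[{"role": "user", "content": "What is 2 + 2?"}]'}

# json.loads accepts only literal JSON, unlike eval(), which would execute arbitrary Python.
messages = json.loads(row["chat_completion_input"])
assert messages == [{"role": "user", "content": "What is 2 + 2?"}]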
|
@ -190,13 +189,13 @@ class MetaReferenceEvalImpl(
|
|||
benchmark_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
task_config: BenchmarkConfig,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
candidate = task_config.eval_candidate
|
||||
candidate = benchmark_config.eval_candidate
|
||||
if candidate.type == "agent":
|
||||
generations = await self._run_agent_generation(input_rows, task_config)
|
||||
generations = await self._run_agent_generation(input_rows, benchmark_config)
|
||||
elif candidate.type == "model":
|
||||
generations = await self._run_model_generation(input_rows, task_config)
|
||||
generations = await self._run_model_generation(input_rows, benchmark_config)
|
||||
else:
|
||||
raise ValueError(f"Invalid candidate type: {candidate.type}")
|
||||
|
||||
|
|
@ -205,9 +204,9 @@ class MetaReferenceEvalImpl(
|
|||
input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
|
||||
]
|
||||
|
||||
if task_config.scoring_params is not None:
|
||||
if benchmark_config.scoring_params is not None:
|
||||
scoring_functions_dict = {
|
||||
scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
|
||||
scoring_fn_id: benchmark_config.scoring_params.get(scoring_fn_id, None)
|
||||
for scoring_fn_id in scoring_functions
|
||||
}
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
|
||||
|
||||
class TokenResult(BaseModel):
|
||||
token: int
|
||||
text: str
|
||||
logprobs: Optional[List[float]] = None
|
||||
|
||||
|
||||
def model_checkpoint_dir(model_id) -> str:
|
||||
checkpoint_dir = Path(model_local_dir(model_id))
|
||||
|
||||
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
|
||||
if not any(p.exists() for p in paths):
|
||||
checkpoint_dir = checkpoint_dir / "original"
|
||||
|
||||
assert checkpoint_dir.exists(), (
|
||||
f"Could not find checkpoints in: {model_local_dir(model_id)}. "
|
||||
f"If you try to use the native llama model, Please download model using `llama download --model-id {model_id}`"
|
||||
f"Otherwise, please save you model checkpoint under {model_local_dir(model_id)}"
|
||||
)
|
||||
return str(checkpoint_dir)
|
||||
|
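A brief usage sketch for the new helper above; the model id is illustrative, and the call simply resolves (and validates) the local checkpoint directory:

# Assumes the checkpoint was fetched earlier with `llama download --model-id Llama3.1-8B-Instruct`.
ckpt_dir = model_checkpoint_dir("Llama3.1-8B-Instruct")
print(f"Loading consolidated checkpoint from {ckpt_dir}")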
|
@ -55,7 +55,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .generation import Llama
|
||||
from .llama3.generation import Llama3
|
||||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -83,7 +83,7 @@ class MetaReferenceInferenceImpl(
|
|||
self.generator = LlamaModelParallelGenerator(self.config, model_id, llama_model)
|
||||
self.generator.start()
|
||||
else:
|
||||
self.generator = Llama.build(self.config, model_id, llama_model)
|
||||
self.generator = Llama3.build(self.config, model_id, llama_model)
|
||||
|
||||
self.model_id = model_id
|
||||
self.llama_model = llama_model
|
||||
|
|
@ -111,7 +111,7 @@ class MetaReferenceInferenceImpl(
|
|||
)
|
||||
if llama_model is None:
|
||||
raise ValueError(
|
||||
"Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list"
|
||||
"Please make sure your llama_model in model metadata or model identifier is in Llama SKU list"
|
||||
)
|
||||
|
||||
self.model_registry_helper = ModelRegistryHelper(
|
||||
|
|
@ -136,11 +136,13 @@ class MetaReferenceInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
|
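This hunk (and the matching ones in the sentence-transformers and vLLM providers later in the diff) moves `SamplingParams()` out of the signature and constructs it lazily; a small sketch, with a stand-in model, of why a per-call default is preferable to a single shared default instance:

from typing import Optional
from pydantic import BaseModel

class FakeSamplingParams(BaseModel):  # stand-in for the real SamplingParams, illustration only
    temperature: float = 0.7

def generate(prompt: str, sampling_params: Optional[FakeSamplingParams] = None) -> FakeSamplingParams:
    # Building the default inside the body gives every call its own instance, so settings
    # mutated by one caller can never leak into later calls through a shared default object.
    if sampling_params is None:
        sampling_params = FakeSamplingParams()
    return sampling_params

assert generate("a") is not generate("b")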
|
@ -208,7 +210,6 @@ class MetaReferenceInferenceImpl(
|
|||
logprobs = []
|
||||
stop_reason = None
|
||||
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
for token_result in self.generator.completion(request):
|
||||
tokens.append(token_result.token)
|
||||
if token_result.text == "<|eot_id|>":
|
||||
|
|
@ -245,7 +246,7 @@ class MetaReferenceInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
|
@ -254,6 +255,8 @@ class MetaReferenceInferenceImpl(
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class QuantizationScheme(Enum):
|
||||
int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuantizationArgs:
|
||||
scheme: Optional[QuantizationScheme] = None
|
||||
group_size: Optional[int] = None
|
||||
spinquant: bool = False
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if k == "scheme":
|
||||
setattr(self, k, QuantizationScheme(v))
|
||||
else:
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAArgs:
|
||||
rank: int
|
||||
scale: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
dim: int = 4096
|
||||
n_layers: int = 32
|
||||
n_heads: int = 32
|
||||
n_kv_heads: Optional[int] = None
|
||||
vocab_size: int = -1
|
||||
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
||||
ffn_dim_multiplier: Optional[float] = None
|
||||
norm_eps: float = 1e-5
|
||||
rope_theta: float = 500000
|
||||
use_scaled_rope: bool = False
|
||||
|
||||
max_batch_size: int = 32
|
||||
max_seq_len: int = 2048
|
||||
|
||||
# vision model params
|
||||
vision_chunk_size: int = -1 # image resolution for image models
|
||||
vision_max_num_chunks: int = 4
|
||||
vision_num_cross_attention_layers: int = -1
|
||||
|
||||
quantization_args: Optional[QuantizationArgs] = None
|
||||
lora_args: Optional[LoRAArgs] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if k == "lora_args":
|
||||
setattr(self, k, LoRAArgs(**v))
|
||||
elif k == "quantization_args":
|
||||
setattr(self, k, QuantizationArgs(**v))
|
||||
else:
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
if self.n_kv_heads is None:
|
||||
self.n_kv_heads = self.n_heads
|
||||
assert self.n_kv_heads <= self.n_heads
|
||||
assert self.n_heads % self.n_kv_heads == 0
|
||||
assert self.dim % self.n_heads == 0
|
||||
|
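The permissive `__init__` above lets `ModelArgs` be built directly from a checkpoint's params.json dict while ignoring unknown keys; a short sketch of that pattern with illustrative values:

params = {
    "dim": 4096,
    "n_layers": 32,
    "n_heads": 32,
    "n_kv_heads": 8,
    "vocab_size": 128256,
    "rope_theta": 500000.0,
    "use_scaled_rope": True,
    "some_future_field": 123,  # silently ignored: ModelArgs has no such attribute
}
args = ModelArgs(**params)
assert args.n_kv_heads == 8 and args.n_heads % args.n_kv_heads == 0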
|
@ -23,15 +23,7 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.llama3.reference_impl.model import Transformer
|
||||
from llama_models.llama3.reference_impl.multimodal.model import (
|
||||
CrossAttentionTransformer,
|
||||
)
|
||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
Fp8QuantizationConfig,
|
||||
|
|
@ -39,46 +31,30 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
)
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
GreedySamplingStrategy,
|
||||
Model,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat, LLMInput
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
from ..common import TokenResult, model_checkpoint_dir
|
||||
from ..config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
from .args import ModelArgs
|
||||
from .model import Transformer
|
||||
from .multimodal.model import CrossAttentionTransformer
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def model_checkpoint_dir(model_id) -> str:
|
||||
checkpoint_dir = Path(model_local_dir(model_id))
|
||||
|
||||
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
|
||||
if not any(p.exists() for p in paths):
|
||||
checkpoint_dir = checkpoint_dir / "original"
|
||||
|
||||
assert checkpoint_dir.exists(), (
|
||||
f"Could not find checkpoints in: {model_local_dir(model_id)}. "
|
||||
f"If you try to use the native llama model, Please download model using `llama download --model-id {model_id}`"
|
||||
f"Otherwise, please save you model checkpoint under {model_local_dir(model_id)}"
|
||||
)
|
||||
return str(checkpoint_dir)
|
||||
|
||||
|
||||
class TokenResult(BaseModel):
|
||||
token: int
|
||||
text: str
|
||||
logprobs: Optional[List[float]] = None
|
||||
|
||||
|
||||
class Llama:
|
||||
class Llama3:
|
||||
@staticmethod
|
||||
def build(
|
||||
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
|
||||
|
|
@ -170,7 +146,7 @@ class Llama:
|
|||
|
||||
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
|
||||
if isinstance(config.quantization, Fp8QuantizationConfig):
|
||||
from .quantization.loader import convert_to_fp8_quantized_model
|
||||
from ..quantization.loader import convert_to_fp8_quantized_model
|
||||
|
||||
# load on CPU in bf16 so that fp8 conversion does not find an
|
||||
# unexpected (fp32, e.g.) datatype
|
||||
|
|
@ -183,7 +159,7 @@ class Llama:
|
|||
model.load_state_dict(state_dict, strict=False)
|
||||
model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
|
||||
elif isinstance(config.quantization, Int4QuantizationConfig):
|
||||
from .quantization.loader import convert_to_int4_quantized_model
|
||||
from ..quantization.loader import convert_to_int4_quantized_model
|
||||
|
||||
model = Transformer(model_args)
|
||||
model = convert_to_int4_quantized_model(model, model_args, config)
|
||||
|
|
@ -193,7 +169,7 @@ class Llama:
|
|||
# Add a wrapper for adding hadamard transform for spinquant.
|
||||
# This needs to be done after loading the state dict otherwise an error will be raised while
|
||||
# loading the state dict.
|
||||
from .quantization.hadamard_utils import (
|
||||
from ..quantization.hadamard_utils import (
|
||||
add_hadamard_transform_for_spinquant,
|
||||
)
|
||||
|
||||
|
|
@ -222,7 +198,7 @@ class Llama:
|
|||
model.to(device)
|
||||
|
||||
log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||
return Llama(model, tokenizer, model_args, llama_model_id)
|
||||
return Llama3(model, tokenizer, model_args, llama_model_id)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -0,0 +1,311 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.layers import (
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from .args import ModelArgs
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
return output * self.weight
|
||||
|
||||
|
||||
def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
|
||||
# Values obtained from grid search
|
||||
scale_factor = 8
|
||||
low_freq_factor = 1
|
||||
high_freq_factor = 4
|
||||
old_context_len = 8192 # original llama3 length
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
wavelen = 2 * torch.pi / freqs
|
||||
new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
|
||||
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||
return torch.where(
|
||||
(wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
|
||||
(1 - smooth) * new_freqs / scale_factor + smooth * new_freqs,
|
||||
new_freqs,
|
||||
)
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
|
||||
if use_scaled:
|
||||
freqs = apply_scaling(freqs)
|
||||
freqs = torch.outer(t, freqs)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
return freqs_cis
|
||||
|
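# Editorial sketch (not part of the diff): a quick shape check for the helper above. Each row of
# freqs_cis holds one complex rotation per pair of head dimensions, so the result has shape
# (end, dim // 2) and dtype complex64. The sizes below are illustrative.
_freqs_cis = precompute_freqs_cis(dim=128, end=4096, theta=500000.0, use_scaled=True)
assert _freqs_cis.shape == (4096, 64)
assert _freqs_cis.dtype == torch.complex64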
||||
|
||||
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis.view(*shape)
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
|
||||
bs, slen, n_kv_heads, head_dim = x.shape
|
||||
if n_rep == 1:
|
||||
return x
|
||||
return (
|
||||
x[:, :, :, None, :]
|
||||
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
|
||||
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
||||
)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
model_parallel_size = fs_init.get_model_parallel_world_size()
|
||||
self.n_local_heads = args.n_heads // model_parallel_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
|
||||
self.wq = ColumnParallelLinear(
|
||||
args.dim,
|
||||
args.n_heads * self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wk = ColumnParallelLinear(
|
||||
args.dim,
|
||||
self.n_kv_heads * self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wv = ColumnParallelLinear(
|
||||
args.dim,
|
||||
self.n_kv_heads * self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wo = RowParallelLinear(
|
||||
args.n_heads * self.head_dim,
|
||||
args.dim,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
|
||||
self.cache_k = torch.zeros(
|
||||
(
|
||||
args.max_batch_size,
|
||||
args.max_seq_len,
|
||||
self.n_local_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
)
|
||||
self.cache_v = torch.zeros(
|
||||
(
|
||||
args.max_batch_size,
|
||||
args.max_seq_len,
|
||||
self.n_local_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
):
|
||||
bsz, seqlen, _ = x.shape
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||
|
||||
self.cache_k = self.cache_k.to(xq)
|
||||
self.cache_v = self.cache_v.to(xq)
|
||||
|
||||
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
|
||||
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
|
||||
|
||||
keys = self.cache_k[:bsz, : start_pos + seqlen]
|
||||
values = self.cache_v[:bsz, : start_pos + seqlen]
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
|
||||
values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
|
||||
|
||||
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
|
||||
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
|
||||
values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
|
||||
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
if mask is not None:
|
||||
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
|
||||
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
||||
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
|
||||
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
||||
return self.wo(output)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: Optional[float],
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
|
||||
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
|
||||
self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, layer_id: int, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_heads = args.n_heads
|
||||
self.dim = args.dim
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = FeedForward(
|
||||
dim=args.dim,
|
||||
hidden_dim=4 * args.dim,
|
||||
multiple_of=args.multiple_of,
|
||||
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
||||
)
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
):
|
||||
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
|
||||
out = h + self.feed_forward(self.ffn_norm(h))
|
||||
return out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, params: ModelArgs):
|
||||
super().__init__()
|
||||
self.params = params
|
||||
self.vocab_size = params.vocab_size
|
||||
self.n_layers = params.n_layers
|
||||
|
||||
self.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)
|
||||
|
||||
self.layers = torch.nn.ModuleList()
|
||||
for layer_id in range(params.n_layers):
|
||||
self.layers.append(TransformerBlock(layer_id, params))
|
||||
|
||||
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
|
||||
self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)
|
||||
|
||||
self.freqs_cis = precompute_freqs_cis(
|
||||
params.dim // params.n_heads,
|
||||
params.max_seq_len * 2,
|
||||
params.rope_theta,
|
||||
params.use_scaled_rope,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, tokens: torch.Tensor, start_pos: int):
|
||||
_bsz, seqlen = tokens.shape
|
||||
h = self.tok_embeddings(tokens)
|
||||
self.freqs_cis = self.freqs_cis.to(h.device)
|
||||
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
|
||||
|
||||
mask = None
|
||||
if seqlen > 1:
|
||||
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
|
||||
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/100005
|
||||
# torch.triu is buggy when the device is mps: filled values are
|
||||
# nan instead of 0.
|
||||
if mask.device.type == torch.device("mps").type:
|
||||
mask = torch.nan_to_num(mask, nan=0.0)
|
||||
|
||||
# When performing key-value caching, we compute the attention scores
|
||||
# only for the new sequence. Thus, the matrix of scores is of size
|
||||
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
|
||||
# j > cache_len + i, since row i corresponds to token cache_len + i.
|
||||
mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
|
||||
|
||||
for layer in self.layers:
|
||||
h = layer(h, start_pos, freqs_cis, mask)
|
||||
h = self.norm(h)
|
||||
output = self.output(h).float()
|
||||
return output
|
||||
|
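The `start_pos` argument together with the per-layer `cache_k`/`cache_v` buffers makes the forward pass incremental: the prompt is processed once, then decoding proceeds one token at a time. A rough sketch of that calling pattern, assuming an already constructed `Transformer` with model parallelism initialized and device placement omitted:

import torch

def greedy_decode(model: Transformer, prompt_tokens: list, max_new_tokens: int) -> list:
    tokens = list(prompt_tokens)
    prev_pos = 0
    for _ in range(max_new_tokens):
        # Feed only the not-yet-cached suffix; earlier positions are served from the KV cache.
        cur = torch.tensor([tokens[prev_pos:]], dtype=torch.long)
        logits = model(cur, prev_pos)                  # (bsz, seqlen, vocab_size)
        next_token = int(torch.argmax(logits[0, -1]))  # greedy pick for the last position
        prev_pos = len(tokens)
        tokens.append(next_token)
    return tokens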
|
@ -0,0 +1,12 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
|
@ -0,0 +1,179 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and its affiliates.
|
||||
import math
|
||||
from logging import getLogger
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .utils import get_negative_inf_value, to_2tuple
|
||||
|
||||
logger = getLogger()
|
||||
|
||||
|
||||
def resize_local_position_embedding(orig_pos_embed, grid_size):
|
||||
"""
|
||||
Resize position embedding for vision encoder.
|
||||
Original position embedding is [n_tiles * n_tiles + 1, dim]
|
||||
New position embedding will be [grid_size[0] * grid_size[1] + 1, dim]
|
||||
"""
|
||||
new_grid_size = to_2tuple(grid_size)
|
||||
orig_grid_size = to_2tuple(int(math.sqrt(len(orig_pos_embed) - 1)))
|
||||
|
||||
new_pos_emb_tok, new_pos_emb_img = (
|
||||
orig_pos_embed[:1],
|
||||
orig_pos_embed[1:],
|
||||
)
|
||||
logger.info(f"resizing position embedding grid-size from {orig_grid_size} to {new_grid_size}")
|
||||
|
||||
new_pos_emb_img = new_pos_emb_img.reshape(1, orig_grid_size[0], orig_grid_size[1], -1).permute(0, 3, 1, 2)
|
||||
|
||||
new_pos_emb_img = F.interpolate(
|
||||
new_pos_emb_img,
|
||||
size=new_grid_size,
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1).reshape(1, new_grid_size[0] * new_grid_size[1], -1)[0]
|
||||
new_pos_embed = torch.cat([new_pos_emb_tok, new_pos_emb_img], dim=0)
|
||||
return new_pos_embed
|
||||
|
||||
|
||||
def initialize_global_position_embedding_from_local(pos_and_cls_embed, grid_size, x_scale, y_scale):
|
||||
"""
|
||||
Takes a local position embedding for vision encoder and uses it
|
||||
to initialize the global position embedding.
|
||||
Input: local position embedding of shape [grid_size[0] * grid_size[1] + 1, dim]
|
||||
Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
|
||||
Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
|
||||
"""
|
||||
pos_embed = pos_and_cls_embed[1:]
|
||||
cls_embed = pos_and_cls_embed[0].view(1, 1, 1, -1)
|
||||
grid_size = to_2tuple(grid_size)
|
||||
new_pos_emb_img = pos_embed.reshape(1, grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2)
|
||||
new_grid_size = (x_scale * grid_size[0], y_scale * grid_size[1])
|
||||
new_pos_emb_img = F.interpolate(
|
||||
new_pos_emb_img,
|
||||
size=new_grid_size,
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1)
|
||||
new_pos_emb_img = new_pos_emb_img.view(x_scale, grid_size[0], y_scale, grid_size[1], -1)
|
||||
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 1, 3, 4).contiguous()
|
||||
new_pos_emb_img = new_pos_emb_img.reshape(x_scale, y_scale, grid_size[0] * grid_size[1], -1)
|
||||
cls_embed = cls_embed.expand(x_scale, y_scale, -1, -1)
|
||||
pos_and_cls_embed = torch.cat([cls_embed, new_pos_emb_img], dim=2)
|
||||
return pos_and_cls_embed
|
||||
|
||||
|
||||
def resize_global_position_embedding(pos_and_cls_embed, grid_size, x_scale, y_scale):
|
||||
"""
|
||||
Takes a global position embedding for vision encoder and resizes it to new size.
|
||||
Input: global position embedding of shape [x_old, y_old, old_grid_size[0] * old_grid_size[1] + 1, dim]
|
||||
Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
|
||||
Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
|
||||
"""
|
||||
# first remove cls token
|
||||
pos_embed = pos_and_cls_embed[:, :, 1:]
|
||||
cls_embed = pos_and_cls_embed[:, :, 0].unsqueeze(2)
|
||||
|
||||
xs_old, ys_old, ntok, dim = pos_embed.shape
|
||||
old_grid_size = int(math.sqrt(ntok))
|
||||
|
||||
# move to correct form for interpolation
|
||||
pos_embed = pos_embed.view(xs_old, ys_old, old_grid_size, old_grid_size, dim)
|
||||
pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
|
||||
pos_embed = pos_embed.view(xs_old * old_grid_size, ys_old * old_grid_size, dim)
|
||||
pos_embed = pos_embed.unsqueeze(0)
|
||||
|
||||
# interpolate
|
||||
new_size = (grid_size[0] * x_scale, grid_size[1] * y_scale)
|
||||
pos_embed = pos_embed.permute(0, 3, 1, 2)
|
||||
pos_embed_resized = F.interpolate(
|
||||
pos_embed,
|
||||
size=new_size,
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
pos_embed = pos_embed_resized.permute(0, 2, 3, 1)[0]
|
||||
|
||||
# move it back in place
|
||||
pos_embed = pos_embed.view(x_scale, grid_size[0], y_scale, grid_size[1], dim)
|
||||
pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
|
||||
pos_embed = pos_embed.view(x_scale, y_scale, grid_size[0] * grid_size[1], dim)
|
||||
|
||||
# interpolate cls token
|
||||
cls_embed = cls_embed.permute(2, 3, 0, 1)
|
||||
cls_embed_resized = F.interpolate(
|
||||
cls_embed,
|
||||
size=(x_scale, y_scale),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
cls_embed = cls_embed_resized.permute(2, 3, 0, 1)
|
||||
# add cls token back in
|
||||
pos_and_cls_embed = torch.cat([cls_embed, pos_embed], dim=2)
|
||||
|
||||
return pos_and_cls_embed
|
||||
|
||||
|
||||
def build_encoder_attention_mask(
|
||||
x: torch.Tensor,
|
||||
ar: torch.Tensor,
|
||||
ntok: int,
|
||||
num_chunks: int,
|
||||
n_heads: int,
|
||||
):
|
||||
"""
|
||||
Build vision encoder attention mask that omits padding tokens.
|
||||
"""
|
||||
masks = []
|
||||
for arx in ar:
|
||||
mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype)
|
||||
mask_i[: arx[0] * arx[1], :ntok] = 0
|
||||
mask_i = mask_i.view(num_chunks * x.shape[2], -1)
|
||||
mask_i = mask_i @ mask_i.T * get_negative_inf_value(x.dtype)
|
||||
mask_i = mask_i.unsqueeze(0)
|
||||
masks.append(mask_i)
|
||||
masks = torch.stack(masks).to(x.device).expand(-1, n_heads, -1, -1)
|
||||
return masks
|
||||
|
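# Editorial sketch (not part of the diff): exercising the mask builder above with toy sizes to show
# the expected output shape (batch, n_heads, num_chunks * padded_tokens, num_chunks * padded_tokens).
_x = torch.zeros(2, 4, 10, 32)        # (batch, num_chunks, padded tokens per chunk, hidden dim)
_ar = torch.tensor([[1, 2], [2, 2]])  # tiles actually used per image: 1x2 and 2x2
_masks = build_encoder_attention_mask(_x, _ar, ntok=9, num_chunks=4, n_heads=8)
assert _masks.shape == (2, 8, 40, 40)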
||||
|
||||
def expand_num_tokens_to_mult8(x):
|
||||
num_pad_tokens = 8 - (x.shape[-2] % 8)
|
||||
if num_pad_tokens == 0:
|
||||
return x, 0
|
||||
else:
|
||||
return (
|
||||
torch.cat(
|
||||
[
|
||||
x,
|
||||
torch.zeros(
|
||||
(x.shape[0], x.shape[1], num_pad_tokens, x.shape[-1]),
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
),
|
||||
],
|
||||
dim=-2,
|
||||
),
|
||||
num_pad_tokens,
|
||||
)
|
||||
|
||||
|
||||
def contract_num_tokens_from_mult8(x, num_pad_tokens):
|
||||
if num_pad_tokens == 0:
|
||||
return x
|
||||
return x[:, :, :-num_pad_tokens]
|
||||
|
|
@ -0,0 +1,408 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from logging import getLogger
|
||||
from typing import Any, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv
|
||||
from PIL import Image
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
IMAGE_RES = 224
|
||||
|
||||
logger = getLogger()
|
||||
|
||||
|
||||
class VariableSizeImageTransform(object):
|
||||
"""
|
||||
This class accepts images of any size and dynamically resizes, pads, and chunks them
|
||||
based on the image aspect ratio and the number of image chunks we allow.
|
||||
|
||||
The algorithm will NOT distort the image to fit a certain aspect ratio, because
|
||||
that leads to a significant degradation in image quality.
|
||||
|
||||
It can be summarized in 6 steps:
|
||||
1. Find all possible canvas combinations of max_num_chunks;
|
||||
2. Find the best canvas to fit the image;
|
||||
3. Resize without distortion
|
||||
4. Pad
|
||||
5. Normalize
|
||||
6. Chunk
|
||||
|
||||
For example, if an input image is of size 300x800, patch_size of 224,
|
||||
and max_num_chunks = 8, it will find the closest aspect ratio that
|
||||
is allowed within 8 image chunks, with some restrictions.
|
||||
In this case, 2:4 = 2 horizontal patches and 4 vertical patches,
|
||||
giving a total of 8 chunks.
|
||||
|
||||
If resize_to_max_canvas, the image will be resized (without distortion),
|
||||
to the largest possible resolution. In this case, 388:896, and padded to 448:896,
|
||||
where we maintain the original aspect ratio and pad with zeros for the rest.
|
||||
This approach minimizes the amount of padding required for any arbitrary resolution.
|
||||
|
||||
However, if limit_upscaling_to_patch_size is set to True,
|
||||
the upscaling will be limited to the patch size. In the example above,
|
||||
the image would remain 300x800 (no upscaling) and would then be padded to 448:896.
|
||||
|
||||
The final output will therefore be of shape (8, 3, 224, 224), where 2x4
|
||||
patches come from the resizing and chunking.
|
||||
"""
|
||||
|
||||
def __init__(self, size: int = IMAGE_RES) -> None:
|
||||
self.size = size
|
||||
logger.info(f"VariableSizeImageTransform size: {self.size}")
|
||||
self.to_tensor = tv.ToTensor()
|
||||
self._mean = (0.48145466, 0.4578275, 0.40821073)
|
||||
self._std = (0.26862954, 0.26130258, 0.27577711)
|
||||
self.normalize = tv.Normalize(
|
||||
mean=self._mean,
|
||||
std=self._std,
|
||||
inplace=True,
|
||||
)
|
||||
self.resample = tv.InterpolationMode.BILINEAR
|
||||
|
||||
@staticmethod
|
||||
def get_factors(n: int) -> Set[int]:
|
||||
"""
|
||||
Calculate all factors of a given number, i.e. a divisor that leaves
|
||||
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
|
||||
|
||||
Args:
|
||||
n (int): The number to find factors for.
|
||||
|
||||
Returns:
|
||||
set: A set containing all factors of the number.
|
||||
"""
|
||||
factors_set = set()
|
||||
|
||||
for i in range(1, int(n**0.5) + 1):
|
||||
if n % i == 0:
|
||||
factors_set.add(i)
|
||||
factors_set.add(n // i)
|
||||
return factors_set
|
||||
|
||||
def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
|
||||
"""
|
||||
Computes all of the allowed resolutions for a fixed number of chunks
|
||||
and patch_size. Useful for when dividing an image into chunks.
|
||||
|
||||
Args:
|
||||
max_num_chunks (int): Maximum number of chunks for processing.
|
||||
patch_size (int): Size of the side of the patch.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: List of possible resolutions as tuples (height, width).
|
||||
|
||||
Example:
|
||||
>>> max_num_chunks = 5
|
||||
>>> patch_size = 224
|
||||
>>> find_supported_resolutions(max_num_chunks, patch_size)
|
||||
tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
|
||||
(672, 224), (224, 448), (448, 224)])
|
||||
|
||||
Given max_num_chunks=4, patch_size=224, it will create a dictionary:
|
||||
{
|
||||
0.25: [(1, 4)],
|
||||
1.0: [(2, 2), (1, 1)],
|
||||
4.0: [(4, 1)],
|
||||
0.33: [(1, 3)],
|
||||
3.0: [(3, 1)],
|
||||
0.5: [(1, 2)],
|
||||
2.0: [(2, 1)]
|
||||
}
|
||||
|
||||
and return the resolutions multiplied by the patch_size:
|
||||
[(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
|
||||
"""
|
||||
asp_dict = defaultdict(list)
|
||||
for chunk_size in range(max_num_chunks, 0, -1):
|
||||
_factors = sorted(self.get_factors(chunk_size))
|
||||
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
|
||||
for height, width in _asp_ratios:
|
||||
ratio_float = height / width
|
||||
asp_dict[ratio_float].append((height, width))
|
||||
|
||||
# get the resolutions multiplied by the patch_size
|
||||
possible_resolutions = []
|
||||
for value in asp_dict.values():
|
||||
for height, depth in value:
|
||||
possible_resolutions.append((height * patch_size, depth * patch_size))
|
||||
|
||||
return possible_resolutions
|
||||
|
||||
@staticmethod
|
||||
def get_max_res_without_distortion(
|
||||
image_size: Tuple[int, int],
|
||||
target_size: Tuple[int, int],
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determines the maximum resolution to which an image can be resized without distorting its
|
||||
aspect ratio, based on the target resolution.
|
||||
|
||||
Args:
|
||||
image_size (Tuple[int, int]): The original resolution of the image (height, width).
|
||||
target_resolution (Tuple[int, int]): The desired resolution to fit the image into (height, width).
|
||||
Returns:
|
||||
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
|
||||
Example:
|
||||
>>> _get_max_res_without_distortion([200, 300], target_size = [450, 200])
|
||||
(134, 200)
|
||||
>>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300])
|
||||
(450, 338)
|
||||
"""
|
||||
|
||||
original_width, original_height = image_size
|
||||
target_width, target_height = target_size
|
||||
|
||||
scale_w = target_width / original_width
|
||||
scale_h = target_height / original_height
|
||||
|
||||
if scale_w < scale_h:
|
||||
new_width = target_width
|
||||
new_height = min(math.floor(original_height * scale_w), target_height)
|
||||
else:
|
||||
new_height = target_height
|
||||
new_width = min(math.floor(original_width * scale_h), target_width)
|
||||
|
||||
return new_width, new_height
|
||||
|
||||
def _pad(self, image: Image.Image, target_size) -> Image.Image:
|
||||
new_width, new_height = target_size
|
||||
new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore
|
||||
new_im.paste(image)
|
||||
return new_im
|
||||
|
||||
def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
|
||||
# Split image into number of required tiles (width x height)
|
||||
num_channels, height, width = image.size()
|
||||
image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
|
||||
# Permute dimensions to reorder the axes
|
||||
image = image.permute(1, 3, 0, 2, 4).contiguous()
|
||||
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
|
||||
image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
|
||||
return image
|
||||
|
||||
def resize_without_distortion(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
target_size: Tuple[int, int],
|
||||
max_upscaling_size: Optional[int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Used to resize an image to target_resolution, without distortion.
|
||||
|
||||
If target_size requires upscaling the image, the user can set max_upscaling_size to
|
||||
limit the upscaling to a maximum size. In this case, since we rescale without distortion,
|
||||
modifying target_size works as a boundary for the image's largest side.
|
||||
|
||||
Args:
|
||||
resample (str): Resampling method used when resizing images.
|
||||
Supports "nearest", "nearest_exact", "bilinear", "bicubic".
|
||||
max_upscaling_size (int): The maximum size to upscale the image to.
|
||||
If None, there is no limit.
|
||||
Examples:
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 600
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(600, 300) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 600
|
||||
>>> image_size = (2000, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 100) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 2000
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 500) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = None
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 500) # new_size_without_distortion
|
||||
"""
|
||||
|
||||
image_width, image_height = image.size
|
||||
image_size = (image_width, image_height)
|
||||
|
||||
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
|
||||
if max_upscaling_size is not None:
|
||||
new_target_width = min(max(image_width, max_upscaling_size), target_size[0])
|
||||
new_target_height = min(max(image_height, max_upscaling_size), target_size[1])
|
||||
target_size = (new_target_width, new_target_height)
|
||||
|
||||
# resize to target_size while preserving aspect ratio
|
||||
new_size_without_distortion = self.get_max_res_without_distortion(image_size, target_size)
|
||||
|
||||
image = F.resize(
|
||||
image,
|
||||
(new_size_without_distortion[1], new_size_without_distortion[0]),
|
||||
interpolation=self.resample,
|
||||
)
|
||||
|
||||
return image
|
||||
|
||||
def get_best_fit(
|
||||
self,
|
||||
image_size: Tuple[int, int],
|
||||
possible_resolutions: torch.Tensor,
|
||||
resize_to_max_canvas: bool = False,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determines the best canvas, from the list of possible resolutions, onto which an image
can be resized without distortion.
|
||||
|
||||
For each possible resolution, calculates the scaling factors for
|
||||
width and height, and selects the smallest one, which is the limiting side.
|
||||
E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
|
||||
therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
|
||||
|
||||
If upscaling is possible (any of the scaling factors is greater than 1),
|
||||
then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
|
||||
|
||||
If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
|
||||
reducing downscaling as much as possible.
|
||||
|
||||
If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
|
||||
to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
|
||||
has more padding.
|
||||
|
||||
Args:
|
||||
image_size (Tuple[int, int]): A tuple containing the height and width of the image.
|
||||
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
|
||||
row represents a possible resolution (height, width).
|
||||
resize_to_max_canvas (bool): If True, will return the largest upscaling resolution.
|
||||
|
||||
Returns:
|
||||
List[int]: The best resolution [height, width] for the given image.
|
||||
|
||||
Example:
|
||||
>>> image_size = (200, 300)
|
||||
>>> possible_resolutions = torch.tensor([[224, 672],
|
||||
... [672, 224],
|
||||
... [224, 448],
|
||||
... [448, 224],
|
||||
... [224, 224]])
|
||||
>>> get_best_fit(image_size, possible_resolutions)
|
||||
[224, 448]
|
||||
|
||||
We have:
|
||||
scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
|
||||
scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
|
||||
scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
|
||||
Only one of the scales > 1:
|
||||
upscaling_possible = tensor([1.1200, 1.1200])
|
||||
smallest_rescale = tensor(1.1200)
|
||||
So we pick the resolution with the smallest area:
|
||||
areas = tensor([150528, 100352]) # [672, 224], [224, 448]
|
||||
optimal_canvas = tensor([224, 448])
|
||||
"""
|
||||
|
||||
original_width, original_height = image_size
|
||||
|
||||
# get all possible resolutions heights/widths
|
||||
target_widths, target_heights = (
|
||||
possible_resolutions[:, 0],
|
||||
possible_resolutions[:, 1],
|
||||
)
|
||||
|
||||
# get scaling factors to resize the image without distortion
|
||||
scale_w = target_widths / original_width
|
||||
scale_h = target_heights / original_height
|
||||
|
||||
# get the min scale between width and height (limiting side -> no distortion)
|
||||
scales = torch.where(scale_w > scale_h, scale_h, scale_w)
|
||||
|
||||
# filter only scales that allow upscaling
|
||||
upscaling_options = scales[scales >= 1]
|
||||
if len(upscaling_options) > 0:
|
||||
if resize_to_max_canvas:
|
||||
selected_scale = torch.max(upscaling_options)
|
||||
else:
|
||||
selected_scale = torch.min(upscaling_options)
|
||||
else:
|
||||
# no upscaling possible,
|
||||
# get the minimum downscaling (max scale for scales<1)
|
||||
downscaling_options = scales[scales < 1]
|
||||
selected_scale = torch.max(downscaling_options)
|
||||
|
||||
# get all resolutions that support this scaling factor,
|
||||
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
|
||||
chosen_canvas = possible_resolutions[scales == selected_scale]
|
||||
|
||||
# if there are multiple resolutions,
|
||||
# get the one with minimum area to reduce padding
|
||||
if len(chosen_canvas) > 1:
|
||||
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
|
||||
optimal_idx = torch.argmin(areas)
|
||||
optimal_canvas = chosen_canvas[optimal_idx]
|
||||
else:
|
||||
optimal_canvas = chosen_canvas[0]
|
||||
|
||||
return tuple(optimal_canvas.tolist())
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
image: Image.Image,
|
||||
max_num_chunks: int,
|
||||
normalize_img: bool = True,
|
||||
resize_to_max_canvas: bool = False,
|
||||
) -> Tuple[Any, Any]:
|
||||
"""
|
||||
Args:
|
||||
image (PIL.Image): Image to be resized.
|
||||
max_num_chunks (int): Maximum number of chunks to split the image into.
|
||||
normalize_img (bool): Whether to normalize the image.
|
||||
resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size.
|
||||
If True, picks the canvas that allows the largest resizing without distortion.
|
||||
If False, downsample as little as possible, including no resizing at all,
|
||||
but never upsample, unless the image is smaller than the patch size.
|
||||
"""
|
||||
assert max_num_chunks > 0
|
||||
assert isinstance(image, Image.Image), type(image)
|
||||
w, h = image.size
|
||||
|
||||
possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
|
||||
possible_resolutions = torch.tensor(possible_resolutions)
|
||||
|
||||
best_resolution = self.get_best_fit(
|
||||
image_size=(w, h),
|
||||
possible_resolutions=possible_resolutions,
|
||||
resize_to_max_canvas=resize_to_max_canvas,
|
||||
)
|
||||
|
||||
max_upscaling_size = None if resize_to_max_canvas else self.size
|
||||
image = self.resize_without_distortion(image, best_resolution, max_upscaling_size)
|
||||
image = self._pad(image, best_resolution)
|
||||
|
||||
image = self.to_tensor(image)
|
||||
|
||||
if normalize_img:
|
||||
image = self.normalize(image)
|
||||
|
||||
ratio_w, ratio_h = (
|
||||
best_resolution[0] // self.size,
|
||||
best_resolution[1] // self.size,
|
||||
)
|
||||
|
||||
image = self._split(image, ratio_w, ratio_h) # type: ignore
|
||||
|
||||
ar = (ratio_h, ratio_w)
|
||||
return image, ar
|
||||
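To tie together the steps described in the class docstring, a short usage sketch; the image path, tile size, and chunk budget are illustrative:

from PIL import Image

transform = VariableSizeImageTransform(size=448)
img = Image.open("example.jpg").convert("RGB")  # any resolution
chunks, aspect_ratio = transform(img, max_num_chunks=4, resize_to_max_canvas=True)
# chunks has shape (n_tiles, 3, 448, 448); aspect_ratio is (tiles_high, tiles_wide), n_tiles = their product
print(chunks.shape, aspect_ratio)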
File diff suppressed because it is too large
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import collections
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_negative_inf_value(dtype):
|
||||
return torch.finfo(dtype).min
|
||||
|
||||
|
||||
def to_2tuple(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
return (x, x)
|
||||
|
|
@ -9,18 +9,18 @@ from copy import deepcopy
|
|||
from functools import partial
|
||||
from typing import Any, Generator
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.models.llama.datatypes import Model
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .common import model_checkpoint_dir
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .generation import Llama, model_checkpoint_dir
|
||||
from .llama3.generation import Llama3
|
||||
from .parallel_utils import ModelParallelProcessGroup
|
||||
|
||||
|
||||
|
|
@ -43,7 +43,7 @@ def init_model_cb(
|
|||
model_id: str,
|
||||
llama_model: Model,
|
||||
):
|
||||
llama = Llama.build(config, model_id, llama_model)
|
||||
llama = Llama3.build(config, model_id, llama_model)
|
||||
return ModelRunner(llama)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .generation import TokenResult
|
||||
from .common import TokenResult
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -207,7 +207,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage
|
|||
return parse_message(maybe_json)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
except ValueError as e:
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -352,7 +352,7 @@ class ModelParallelProcessGroup:
|
|||
if isinstance(obj, TaskResponse):
|
||||
yield obj.result
|
||||
|
||||
except GeneratorExit as e:
|
||||
except GeneratorExit:
|
||||
self.request_socket.send(encode_msg(CancelSentinel()))
|
||||
while True:
|
||||
obj_json = self.request_socket.send()
|
||||
|
|
|
|||
|
|
@ -7,6 +7,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
# The file gets a special treatment for now?
|
||||
# ruff: noqa: N803
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
|
|
|||
|
|
@ -15,8 +15,6 @@ import torch
|
|||
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
|
||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
||||
from torch import Tensor, nn
|
||||
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
||||
|
||||
|
|
@ -24,6 +22,8 @@ from llama_stack.apis.inference import QuantizationType
|
|||
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
|
||||
from ...llama3.args import ModelArgs
|
||||
from ...llama3.model import Transformer, TransformerBlock
|
||||
from ..config import MetaReferenceQuantizedInferenceConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
|
|||
|
|
@ -22,11 +22,11 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama3.args import ModelArgs
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama3.model import Transformer, TransformerBlock
|
||||
from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import (
|
||||
quantize_fp8,
|
||||
)
|
||||
|
|
|
|||
|
|
@@ -21,7 +21,7 @@ NPROC=$7

echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR

-NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" PYTHONPATH="/home/$USER/llama-models:/home/$USER/llama-stack" \
+NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" PYTHONPATH="/home/$USER/llama-stack" \
torchrun \
--nnodes=$NNODES --nproc_per_node=$NPROC \
--rdzv_id=$RUN_ID \

@@ -53,7 +53,7 @@ class SentenceTransformersInferenceImpl(
self,
model_id: str,
content: str,
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,

@@ -64,7 +64,7 @@ class SentenceTransformersInferenceImpl(
self,
model_id: str,
messages: List[Message],
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,

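The recurring `SamplingParams()` to `None` change across these provider signatures avoids Python's mutable-default-argument pitfall: a default value is evaluated once at function definition, so a single `SamplingParams` instance would be shared across every call. A minimal sketch of the pattern, using a hypothetical stand-in class rather than the repo's real one:

from dataclasses import dataclass
from typing import Optional

@dataclass
class SamplingParams:          # hypothetical stand-in for the real config object
    temperature: float = 1.0

# Risky: the default instance is created once, at definition time, and shared by every call.
def generate_shared(params: SamplingParams = SamplingParams()):
    return params

# Safer: default to None and build a fresh instance inside the function body.
def generate_fresh(params: Optional[SamplingParams] = None):
    if params is None:
        params = SamplingParams()
    return params

assert generate_shared() is generate_shared()      # same object every time
assert generate_fresh() is not generate_fresh()    # fresh object per call
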
@@ -9,7 +9,6 @@ import os
import uuid
from typing import AsyncGenerator, List, Optional

-from llama_models.llama3.api.tokenizer import Tokenizer
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams as VLLMSamplingParams

@@ -36,6 +35,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model
+from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (

@@ -143,7 +143,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
self,
model_id: str,
content: InterleavedContent,
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,

@@ -154,7 +154,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
self,
model_id: str,
messages: List[Message],
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,

@@ -163,6 +163,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
+if sampling_params is None:
+sampling_params = SamplingParams()
assert self.engine is not None

request = ChatCompletionRequest(

@@ -10,16 +10,19 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+import json
from typing import Any, Mapping

from llama_stack.providers.utils.common.data_schema_validator import ColumnName


-def llama_stack_instruct_to_torchtune_instruct(sample: Mapping[str, Any]) -> Mapping[str, Any]:
+def llama_stack_instruct_to_torchtune_instruct(
+sample: Mapping[str, Any],
+) -> Mapping[str, Any]:
assert ColumnName.chat_completion_input.value in sample and ColumnName.expected_answer.value in sample, (
"Invalid input row"
)
-input_messages = eval(str(sample[ColumnName.chat_completion_input.value]))
+input_messages = json.loads(sample[ColumnName.chat_completion_input.value])

assert len(input_messages) == 1, "llama stack intruct dataset format only supports 1 user message"
input_message = input_messages[0]

@@ -37,7 +40,7 @@ def llama_stack_instruct_to_torchtune_instruct(sample: Mapping[str, Any]) -> Map
def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
assert ColumnName.dialog.value in sample, "Invalid input row"
role_map = {"user": "human", "assistant": "gpt"}
-dialog = eval(str(sample[ColumnName.dialog.value]))
+dialog = json.loads(sample[ColumnName.dialog.value])

assert len(dialog) > 1, "dialog must have at least 2 messagse"
roles = []

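Swapping `eval` for `json.loads` here matters beyond style: `eval` executes arbitrary Python found in the dataset column, while `json.loads` only parses data. A minimal sketch of the difference, using a made-up row rather than an actual dataset sample:

import json

row = '[{"role": "user", "content": "hello"}]'

# Safe: parses the string as data only.
messages = json.loads(row)
assert messages[0]["content"] == "hello"

# Unsafe: eval would happily run code embedded in a malicious row, e.g.
# eval('__import__("os").system("echo pwned")')  # never do this with untrusted input
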
@@ -264,7 +264,7 @@ class LoraFinetuningSingleDevice:
)

self.adapter_params = get_adapter_params(model)
-self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()])
+self._is_dora = any("magnitude" in k for k in self.adapter_params.keys())

set_trainable_params(model, self.adapter_params)

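Dropping the square brackets inside `any(...)` swaps a list comprehension for a generator expression, so `any` can stop at the first match instead of building the whole list first. A small illustration with throwaway data, not taken from the repo:

keys = ["lora_a.weight", "magnitude", "lora_b.weight"]

# List comprehension: materializes every element before any() looks at them.
is_dora_list = any(["magnitude" in k for k in keys])

# Generator expression: lazily evaluated, short-circuits on the first True.
is_dora_gen = any("magnitude" in k for k in keys)

assert is_dora_list and is_dora_gen
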
@@ -12,6 +12,7 @@ from llama_stack.apis.scoring_functions import (
)

MULTILINGUAL_ANSWER_REGEXES = [
r"The best answer is ",
r"Answer\s*:",
+r"Answer\s*:", # Korean invisible character
r"উত্তর\s*:",

@@ -133,7 +133,7 @@ class BraintrustScoringImpl(
async def shutdown(self) -> None: ...

async def list_scoring_functions(self) -> List[ScoringFn]:
-scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
+scoring_fn_defs_list = list(self.supported_fn_defs_registry.values())
for f in scoring_fn_defs_list:
assert f.identifier.startswith("braintrust"), (
"All braintrust scoring fn must have identifier prefixed with 'braintrust'! "

@@ -25,7 +25,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
from .config import LlmAsJudgeScoringConfig
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn

-LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
+LLM_JUDGE_FN = LlmAsJudgeScoringFn


class LlmAsJudgeScoringImpl(

@@ -43,23 +43,17 @@ class LlmAsJudgeScoringImpl(
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.inference_api = inference_api
-self.scoring_fn_id_impls = {}

async def initialize(self) -> None:
-for fn in LLM_JUDGE_FNS:
-impl = fn(inference_api=self.inference_api)
-for fn_defs in impl.get_supported_scoring_fn_defs():
-self.scoring_fn_id_impls[fn_defs.identifier] = impl
-self.llm_as_judge_fn = impl
+impl = LLM_JUDGE_FN(inference_api=self.inference_api)
+self.llm_as_judge_fn = impl

async def shutdown(self) -> None: ...

async def list_scoring_functions(self) -> List[ScoringFn]:
-scoring_fn_defs_list = [
-fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
-]
+scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs()

-for f in scoring_fn_defs_list:
+for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs():
assert f.identifier.startswith("llm-as-judge"), (
"All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
)

@@ -67,7 +61,7 @@ class LlmAsJudgeScoringImpl(
return scoring_fn_defs_list

async def register_scoring_function(self, function_def: ScoringFn) -> None:
-raise NotImplementedError("Register scoring function not implemented yet")
+self.llm_as_judge_fn.register_scoring_fn_def(function_def)

async def score_batch(
self,

@@ -102,9 +96,7 @@ class LlmAsJudgeScoringImpl(
) -> ScoreResponse:
res = {}
for scoring_fn_id in scoring_functions.keys():
-if scoring_fn_id not in self.scoring_fn_id_impls:
-raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
-scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
+scoring_fn = self.llm_as_judge_fn
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)

@@ -6,7 +6,7 @@
import re
from typing import Any, Dict, Optional

-from llama_stack.apis.inference.inference import Inference
+from llama_stack.apis.inference.inference import Inference, UserMessage
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

@@ -58,10 +58,9 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
judge_response = await self.inference_api.chat_completion(
model_id=fn_def.params.judge_model,
messages=[
-{
-"role": "user",
-"content": judge_input_msg,
-}
+UserMessage(
+content=judge_input_msg,
+),
],
)
content = judge_response.completion_message.content

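Passing a `UserMessage` instead of a bare dict lets the message be validated when it is constructed rather than failing later inside the inference call. A rough sketch of why a typed model helps, using a simplified Pydantic stand-in rather than the actual llama-stack class:

from typing import Literal
from pydantic import BaseModel, ValidationError

class UserMessage(BaseModel):  # simplified stand-in, not the real class
    role: Literal["user"] = "user"
    content: str

# Typos in a raw dict only surface when the receiving API inspects the keys.
raw = {"role": "user", "contnet": "Grade this answer"}  # silent typo

# The typed model rejects the same mistake immediately (content is missing).
try:
    UserMessage(contnet="Grade this answer")
except ValidationError as e:
    print("caught at construction:", e.errors()[0]["type"])
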
@@ -44,9 +44,9 @@ class TelemetryConfig(BaseModel):
return v

@classmethod
-def sample_run_config(cls, __distro_dir__: str = "runtime", db_name: str = "trace_store.db") -> Dict[str, Any]:
+def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
-"sqlite_db_path": "${env.SQLITE_DB_PATH:~/.llama/" + __distro_dir__ + "/" + db_name + "}",
+"sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
}

@@ -28,7 +28,7 @@ class SQLiteSpanProcessor(SpanProcessor):
self._local.conn = sqlite3.connect(self.conn_string)
except Exception as e:
print(f"Error connecting to SQLite database: {e}")
-raise e
+raise
return self._local.conn

def setup_database(self):

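The `raise e` to bare `raise` change is the usual linter-recommended form for re-raising: inside an `except` block, a bare `raise` re-raises the active exception as-is, while `raise e` adds a redundant re-raise frame to the traceback (and is flagged by tools such as ruff). A small self-contained illustration, unrelated to the telemetry code:

import sqlite3
import traceback

def connect(path: str) -> sqlite3.Connection:
    try:
        return sqlite3.connect(path)
    except Exception as e:
        print(f"Error connecting to SQLite database: {e}")
        raise  # bare raise: re-raise the active exception without an extra frame

try:
    connect("/no/such/dir/db.sqlite")
except sqlite3.OperationalError:
    traceback.print_exc()  # traceback still points at sqlite3.connect(...)
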
@@ -73,6 +73,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
self.config = config
self.datasetio_api = deps.get(Api.datasetio)
+self.meter = None

resource = Resource.create(
{

@@ -171,6 +172,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
return _GLOBAL_STORAGE["gauges"][name]

def _log_metric(self, event: MetricEvent) -> None:
+if self.meter is None:
+return
if isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes)

@@ -4,13 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-from .code_interpreter import CodeInterpreterToolRuntimeImpl
from .config import CodeInterpreterToolConfig

+__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]


async def get_provider_impl(config: CodeInterpreterToolConfig, _deps):
+from .code_interpreter import CodeInterpreterToolRuntimeImpl

impl = CodeInterpreterToolRuntimeImpl(config)
await impl.initialize()
return impl

@@ -20,7 +20,7 @@ class FaissVectorIOConfig(BaseModel):
kvstore: KVStoreConfig

@classmethod
-def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
+def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,

llama_stack/providers/inline/vector_io/milvus/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
+from .config import MilvusVectorIOConfig
+
+
+async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
+from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
+
+impl = MilvusVectorIOAdapter(config, deps[Api.inference])
+await impl.initialize()
+return impl

llama_stack/providers/inline/vector_io/milvus/config.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict
+
+from pydantic import BaseModel
+
+from llama_stack.schema_utils import json_schema_type
+
+
+@json_schema_type
+class MilvusVectorIOConfig(BaseModel):
+db_path: str
+
+@classmethod
+def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+return {"db_path": "${env.MILVUS_DB_PATH}"}

@@ -15,5 +15,5 @@ class SQLiteVectorIOConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
return {
-"db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + __distro_dir__ + "}/" + "sqlite_vec.db",
+"db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + "sqlite_vec.db",
}

@@ -110,4 +110,22 @@ def available_providers() -> List[ProviderSpec]:
),
api_dependencies=[Api.inference],
),
+remote_provider_spec(
+Api.vector_io,
+AdapterSpec(
+adapter_type="milvus",
+pip_packages=["pymilvus"],
+module="llama_stack.providers.remote.vector_io.milvus",
+config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
+),
+api_dependencies=[Api.inference],
+),
+InlineProviderSpec(
+api=Api.vector_io,
+provider_type="inline::milvus",
+pip_packages=["pymilvus"],
+module="llama_stack.providers.inline.vector_io.milvus",
+config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
+api_dependencies=[Api.inference],
+),
]

@@ -72,7 +72,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
content: InterleavedContent,
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,

@@ -83,7 +83,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
messages: List[Message],
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,

@@ -92,6 +92,8 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+if sampling_params is None:
+sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,

@@ -72,11 +72,13 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
content: InterleavedContent,
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
+if sampling_params is None:
+sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,

@@ -112,7 +114,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
messages: List[Message],
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,

@@ -121,6 +123,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
+if sampling_params is None:
+sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,

@@ -71,7 +71,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
self,
model: str,
content: InterleavedContent,
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,

@@ -82,7 +82,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
self,
model: str,
messages: List[Message],
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,

@@ -91,6 +91,8 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
+if sampling_params is None:
+sampling_params = SamplingParams()
request = ChatCompletionRequest(
model=model,
messages=messages,

@@ -8,6 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union

from fireworks.client import Fireworks

+from llama_stack import logcat
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,

@@ -85,11 +86,13 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
self,
model_id: str,
content: InterleavedContent,
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
+if sampling_params is None:
+sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,

@@ -156,7 +159,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
self,
model_id: str,
messages: List[Message],
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,

@@ -165,6 +168,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
+if sampling_params is None:
+sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,

@@ -226,12 +231,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
if input_dict["prompt"].startswith("<|begin_of_text|>"):
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]

-return {
+params = {
"model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
}
+logcat.debug("inference", f"params to fireworks: {params}")
+return params

async def embeddings(
self,

@@ -93,11 +93,13 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
self,
model_id: str,
content: InterleavedContent,
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
+if sampling_params is None:
+sampling_params = SamplingParams()
if content_has_media(content):
raise NotImplementedError("Media is not supported")

@@ -188,7 +190,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
self,
model_id: str,
messages: List[Message],
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,

@@ -197,8 +199,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+if sampling_params is None:
+sampling_params = SamplingParams()
if tool_prompt_format:
-warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
+warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)

await check_health(self._config) # this raises errors

@@ -106,7 +106,7 @@ async def convert_chat_completion_request(
payload.update(temperature=strategy.temperature)
elif isinstance(strategy, TopKSamplingStrategy):
if strategy.top_k != -1 and strategy.top_k < 1:
-warnings.warn("top_k must be -1 or >= 1")
+warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
nvext.update(top_k=strategy.top_k)
elif isinstance(strategy, GreedySamplingStrategy):
nvext.update(top_k=-1)

@@ -168,7 +168,7 @@ def convert_completion_request(
payload.update(top_p=request.sampling_params.top_p)
elif request.sampling_params.strategy == "top_k":
if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1:
-warnings.warn("top_k must be -1 or >= 1")
+warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
nvext.update(top_k=request.sampling_params.top_k)
elif request.sampling_params.strategy == "greedy":
nvext.update(top_k=-1)

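Adding `stacklevel=2` makes `warnings.warn` attribute the warning to the caller of the helper instead of the helper's own line, which is usually what you want for a reusable conversion function. A quick standalone illustration with toy function names, not taken from the repo:

import warnings

def validate_top_k(top_k: int) -> None:
    if top_k != -1 and top_k < 1:
        # stacklevel=2 points the reported filename/line at whoever called validate_top_k
        warnings.warn("top_k must be -1 or >= 1", stacklevel=2)

def build_request() -> None:
    validate_top_k(0)  # the warning is reported against this line

build_request()
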
@@ -10,6 +10,7 @@ from typing import AsyncGenerator, List, Optional, Union
import httpx
from ollama import AsyncClient

+from llama_stack import logcat
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,

@@ -89,11 +90,13 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
self,
model_id: str,
content: InterleavedContent,
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
+if sampling_params is None:
+sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,

@@ -144,7 +147,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
self,
model_id: str,
messages: List[Message],
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,

@@ -153,6 +156,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
+if sampling_params is None:
+sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,

@@ -203,12 +208,14 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
else:
raise ValueError(f"Unknown response format type: {fmt.type}")

-return {
+params = {
"model": request.model,
**input_dict,
"options": sampling_options,
"stream": request.stream,
}
+logcat.debug("inference", f"params to ollama: {params}")
+return params

async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)

@@ -81,11 +81,13 @@ class PassthroughInferenceAdapter(Inference):
self,
model_id: str,
content: InterleavedContent,
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
+if sampling_params is None:
+sampling_params = SamplingParams()
client = self._get_client()
model = await self.model_store.get_model(model_id)

@@ -107,7 +109,7 @@ class PassthroughInferenceAdapter(Inference):
self,
model_id: str,
messages: List[Message],
-sampling_params: Optional[SamplingParams] = SamplingParams(),
+sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,

@@ -116,6 +118,8 @@ class PassthroughInferenceAdapter(Inference):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
+if sampling_params is None:
+sampling_params = SamplingParams()
client = self._get_client()
model = await self.model_store.get_model(model_id)

Some files were not shown because too many files have changed in this diff.