Merge branch 'main' into pr1573

This commit is contained in:
Xi Yan 2025-03-14 12:18:00 -07:00
commit 0e2a13da9c
39 changed files with 5311 additions and 407 deletions

View file

@ -14,6 +14,7 @@ from llama_stack.schema_utils import json_schema_type
@json_schema_type
class Api(Enum):
providers = "providers"
inference = "inference"
safety = "safety"
agents = "agents"

View file

@ -11,13 +11,6 @@ from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class ProviderInfo(BaseModel):
api: str
provider_id: str
provider_type: str
@json_schema_type
class RouteInfo(BaseModel):
route: str
@ -32,14 +25,21 @@ class HealthInfo(BaseModel):
@json_schema_type
class VersionInfo(BaseModel):
version: str
class ProviderInfo(BaseModel):
api: str
provider_id: str
provider_type: str
class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]
@json_schema_type
class VersionInfo(BaseModel):
version: str
class ListRoutesResponse(BaseModel):
data: List[RouteInfo]

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .providers import * # noqa: F401 F403

View file

@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class ProviderInfo(BaseModel):
api: str
provider_id: str
provider_type: str
config: Dict[str, Any]
class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]
@runtime_checkable
class Providers(Protocol):
"""
Providers API for inspecting, listing, and modifying providers and their configurations.
"""
@webmethod(route="/providers", method="GET")
async def list_providers(self) -> ListProvidersResponse: ...
@webmethod(route="/providers/{provider_id}", method="GET")
async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...

View file

@ -10,7 +10,7 @@ import json
import os
import shutil
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timezone
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional
@ -404,7 +404,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
d = json.load(f)
manifest = Manifest(**d)
if datetime.now() > manifest.expires_on:
if datetime.now(timezone.utc) > manifest.expires_on:
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
console = Console()

View file

@ -41,8 +41,8 @@ class ModelPromptFormat(Subcommand):
"-m",
"--model-name",
type=str,
default="llama3_1",
help="Model Family (llama3_1, llama3_X, etc.)",
help="Example: Llama3.1-8B or Llama3.2-11B-Vision, etc\n"
"(Run `llama model list` to see a list of valid model names)",
)
self.parser.add_argument(
"-l",
@ -60,7 +60,6 @@ class ModelPromptFormat(Subcommand):
]
model_list = [m.value for m in supported_model_ids]
model_str = "\n".join(model_list)
if args.list:
headers = ["Model(s)"]
@ -81,10 +80,16 @@ class ModelPromptFormat(Subcommand):
try:
model_id = CoreModelId(args.model_name)
except ValueError:
self.parser.error(f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}")
self.parser.error(
f"{args.model_name} is not a valid Model. Choose one from the list of valid models. "
f"Run `llama model list` to see the valid model names."
)
if model_id not in supported_model_ids:
self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")
self.parser.error(
f"{model_id} is not a valid Model. Choose one from the list of valid models. "
f"Run `llama model list` to see the valid model names."
)
llama_3_1_file = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"
llama_3_2_text_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "text_prompt_format.md"

View file

@ -62,7 +62,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
if config.apis:
apis_to_serve = config.apis
else:
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect, Api.providers)]
for api_str in apis_to_serve:
api = Api(api_str)

View file

@ -56,7 +56,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
def providable_apis() -> List[Api]:
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:

View file

@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
from .datatypes import StackRunConfig
from .stack import redact_sensitive_fields
class ProviderImplConfig(BaseModel):
run_config: StackRunConfig
async def get_provider_impl(config, deps):
impl = ProviderImpl(config, deps)
await impl.initialize()
return impl
class ProviderImpl(Providers):
def __init__(self, config, deps):
self.config = config
self.deps = deps
async def initialize(self) -> None:
pass
async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.run_config
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
ret = []
for api, providers in safe_config.providers.items():
ret.extend(
[
ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
config=p.config,
)
for p in providers
]
)
return ListProvidersResponse(data=ret)
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
all_providers = await self.list_providers()
for p in all_providers.data:
if p.provider_id == provider_id:
return p
raise ValueError(f"Provider {provider_id} not found")

View file

@ -16,6 +16,7 @@ from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.providers import Providers as ProvidersAPI
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
@ -59,6 +60,7 @@ class InvalidProviderError(Exception):
def api_protocol_map() -> Dict[Api, Any]:
return {
Api.providers: ProvidersAPI,
Api.agents: Agents,
Api.inference: Inference,
Api.inspect: Inspect,
@ -247,6 +249,25 @@ def sort_providers_by_deps(
)
)
sorted_providers.append(
(
"providers",
ProviderWithSpec(
provider_id="__builtin__",
provider_type="__builtin__",
config={"run_config": run_config.model_dump()},
spec=InlineProviderSpec(
api=Api.providers,
provider_type="__builtin__",
config_class="llama_stack.distribution.providers.ProviderImplConfig",
module="llama_stack.distribution.providers",
api_dependencies=apis,
deps__=[x.value for x in apis],
),
),
)
)
logger.debug(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
logger.debug(f" {api_str} => {provider.provider_id}")

View file

@ -368,6 +368,7 @@ def main():
apis_to_serve.add(inf.routing_table_api.value)
apis_to_serve.add("inspect")
apis_to_serve.add("providers")
for api_str in apis_to_serve:
api = Api(api_str)

View file

@ -23,6 +23,7 @@ from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.providers import Providers
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
@ -44,6 +45,7 @@ logger = get_logger(name=__name__, category="core")
class LlamaStack(
Providers,
VectorDBs,
Inference,
BatchInference,

View file

@ -34,7 +34,9 @@ class SystemDefaultGenerator(PromptTemplateGeneratorBase):
)
return PromptTemplate(
template_str.lstrip("\n"),
{"today": datetime.now().strftime("%d %B %Y")},
{
"today": datetime.now().strftime("%d %B %Y") # noqa: DTZ005 - we don't care about timezones here since we are displaying the date
},
)
def data_examples(self) -> List[Any]:

View file

@ -11,7 +11,7 @@ import re
import secrets
import string
import uuid
from datetime import datetime
from datetime import datetime, timezone
from typing import AsyncGenerator, List, Optional, Union
from urllib.parse import urlparse
@ -239,7 +239,7 @@ class ChatAgent(ShieldRunnerMixin):
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
now = datetime.now().astimezone().isoformat()
now = datetime.now(timezone.utc).isoformat()
tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
@ -264,7 +264,7 @@ class ChatAgent(ShieldRunnerMixin):
start_time = last_turn.started_at
else:
messages.extend(request.messages)
start_time = datetime.now().astimezone().isoformat()
start_time = datetime.now(timezone.utc).isoformat()
input_messages = request.messages
output_message = None
@ -295,7 +295,7 @@ class ChatAgent(ShieldRunnerMixin):
input_messages=input_messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now().astimezone().isoformat(),
completed_at=datetime.now(timezone.utc).isoformat(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
@ -386,7 +386,7 @@ class ChatAgent(ShieldRunnerMixin):
return
step_id = str(uuid.uuid4())
shield_call_start_time = datetime.now().astimezone().isoformat()
shield_call_start_time = datetime.now(timezone.utc).isoformat()
try:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -410,7 +410,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
violation=e.violation,
started_at=shield_call_start_time,
completed_at=datetime.now().astimezone().isoformat(),
completed_at=datetime.now(timezone.utc).isoformat(),
),
)
)
@ -433,7 +433,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
violation=None,
started_at=shield_call_start_time,
completed_at=datetime.now().astimezone().isoformat(),
completed_at=datetime.now(timezone.utc).isoformat(),
),
)
)
@ -472,7 +472,7 @@ class ChatAgent(ShieldRunnerMixin):
client_tools[tool.name] = tool
while True:
step_id = str(uuid.uuid4())
inference_start_time = datetime.now().astimezone().isoformat()
inference_start_time = datetime.now(timezone.utc).isoformat()
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
@ -582,7 +582,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
model_response=copy.deepcopy(message),
started_at=inference_start_time,
completed_at=datetime.now().astimezone().isoformat(),
completed_at=datetime.now(timezone.utc).isoformat(),
),
)
)
@ -653,7 +653,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[],
started_at=datetime.now().astimezone().isoformat(),
started_at=datetime.now(timezone.utc).isoformat(),
),
)
yield message
@ -670,7 +670,7 @@ class ChatAgent(ShieldRunnerMixin):
"input": message.model_dump_json(),
},
) as span:
tool_execution_start_time = datetime.now().astimezone().isoformat()
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
tool_call = message.tool_calls[0]
tool_result = await self.execute_tool_call_maybe(
session_id,
@ -708,7 +708,7 @@ class ChatAgent(ShieldRunnerMixin):
)
],
started_at=tool_execution_start_time,
completed_at=datetime.now().astimezone().isoformat(),
completed_at=datetime.now(timezone.utc).isoformat(),
),
)
)

View file

@ -7,7 +7,7 @@
import json
import logging
import uuid
from datetime import datetime
from datetime import datetime, timezone
from typing import List, Optional
from pydantic import BaseModel
@ -36,7 +36,7 @@ class AgentPersistence:
session_info = AgentSessionInfo(
session_id=session_id,
session_name=name,
started_at=datetime.now(),
started_at=datetime.now(timezone.utc),
)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from datetime import datetime, timezone
from typing import Any, Dict, Optional
from llama_stack.apis.datasetio import DatasetIO
@ -64,7 +64,7 @@ class TorchtunePostTrainingImpl:
job_status_response = PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=JobStatus.scheduled,
scheduled_at=datetime.now(),
scheduled_at=datetime.now(timezone.utc),
)
self.jobs[job_uuid] = job_status_response
@ -84,7 +84,7 @@ class TorchtunePostTrainingImpl:
)
job_status_response.status = JobStatus.in_progress
job_status_response.started_at = datetime.now()
job_status_response.started_at = datetime.now(timezone.utc)
await recipe.setup()
resources_allocated, checkpoints = await recipe.train()
@ -93,7 +93,7 @@ class TorchtunePostTrainingImpl:
job_status_response.resources_allocated = resources_allocated
job_status_response.checkpoints = checkpoints
job_status_response.status = JobStatus.completed
job_status_response.completed_at = datetime.now()
job_status_response.completed_at = datetime.now(timezone.utc)
except Exception:
job_status_response.status = JobStatus.failed

View file

@ -8,7 +8,7 @@ import gc
import logging
import os
import time
from datetime import datetime
from datetime import datetime, timezone
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
@ -532,7 +532,7 @@ class LoraFinetuningSingleDevice:
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
checkpoint = Checkpoint(
identifier=f"{self.model_id}-sft-{curr_epoch}",
created_at=datetime.now(),
created_at=datetime.now(timezone.utc),
epoch=curr_epoch,
post_training_job_id=self.job_uuid,
path=checkpoint_path,

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import json
from datetime import datetime
from datetime import datetime, timezone
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanProcessor
@ -34,7 +34,7 @@ class ConsoleSpanProcessor(SpanProcessor):
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime("%H:%M:%S.%f")[:-3]
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]
print(
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
@ -46,7 +46,7 @@ class ConsoleSpanProcessor(SpanProcessor):
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime("%H:%M:%S.%f")[:-3]
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]
span_context = (
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
@ -74,7 +74,7 @@ class ConsoleSpanProcessor(SpanProcessor):
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")
for event in span.events:
event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime("%H:%M:%S.%f")[:-3]
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]
severity = event.attributes.get("severity", "info")
message = event.attributes.get("message", event.name)

View file

@ -8,7 +8,7 @@ import json
import os
import sqlite3
import threading
from datetime import datetime
from datetime import datetime, timezone
from opentelemetry.sdk.trace import SpanProcessor
from opentelemetry.trace import Span
@ -124,8 +124,8 @@ class SQLiteSpanProcessor(SpanProcessor):
trace_id,
service_name,
(span_id if not parent_span_id else None),
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
),
)
@ -143,8 +143,8 @@ class SQLiteSpanProcessor(SpanProcessor):
trace_id,
parent_span_id,
span.name,
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
json.dumps(dict(span.attributes)),
span.status.status_code.name,
span.kind.name,
@ -161,7 +161,7 @@ class SQLiteSpanProcessor(SpanProcessor):
(
span_id,
event.name,
datetime.fromtimestamp(event.timestamp / 1e9).isoformat(),
datetime.fromtimestamp(event.timestamp / 1e9, timezone.utc).isoformat(),
json.dumps(dict(event.attributes)),
),
)

View file

@ -168,7 +168,7 @@ def process_matplotlib_response(response, matplotlib_dump_dir: str):
image_paths = []
for i, img in enumerate(images):
# create new directory for each day to better organize data:
dump_dname = datetime.today().strftime("%Y-%m-%d")
dump_dname = datetime.today().strftime("%Y-%m-%d") # noqa: DTZ002 - we don't care about timezones here since we are displaying the date
dump_dpath = Path(matplotlib_dump_dir, dump_dname)
dump_dpath.mkdir(parents=True, exist_ok=True)
# save image into a file

View file

@ -11,7 +11,7 @@ import logging
import queue
import threading
import uuid
from datetime import datetime
from datetime import datetime, timezone
from functools import wraps
from typing import Any, Callable, Dict, List, Optional
@ -86,7 +86,7 @@ class TraceContext:
span_id=generate_short_uuid(),
trace_id=self.trace_id,
name=name,
start_time=datetime.now(),
start_time=datetime.now(timezone.utc),
parent_span_id=current_span.span_id if current_span else None,
attributes=attributes,
)
@ -203,7 +203,7 @@ class TelemetryHandler(logging.Handler):
UnstructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=datetime.now(),
timestamp=datetime.now(timezone.utc),
message=self.format(record),
severity=severity(record.levelname),
)