Revert "feat: add a configurable category-based logger (#1352)"
This reverts commit 754feba61f.
This commit is contained in:
parent
b8c519ba11
commit
efe1772727
12 changed files with 54 additions and 407 deletions
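
Every one of the twelve files below applies the same substitution: the category-taking llama_stack.logcat helpers introduced by #1352 are stripped out, and each call site goes back to a plain module-level standard-library logger (or, in the routers, a get_logger-based logger). A minimal sketch of the two call shapes this revert moves between; the variable values are illustrative stand-ins, not code from the diff:

    import logging

    log = logging.getLogger(__name__)

    ptype, api = "remote::example", "inference"  # stand-in values

    # Pre-revert style, removed by this commit (category as first argument):
    #   from llama_stack import logcat
    #   logcat.warning("core", f"Provider `{ptype}` for API `{api}` is disabled")

    # Post-revert style, restored by this commit:
    log.warning(f"Provider `{ptype}` for API `{api}` is disabled")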
@@ -6,8 +6,9 @@
 import importlib
 import inspect
 from typing import Any, Dict, List, Set, Tuple
+import logging
+from typing import Any, Dict, List, Set
 
-from llama_stack import logcat
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.benchmarks import Benchmarks
 from llama_stack.apis.datasetio import DatasetIO
@@ -50,6 +51,8 @@ from llama_stack.providers.datatypes import (
     VectorDBsProtocolPrivate,
 )
 
+log = logging.getLogger(__name__)
+
 
 class InvalidProviderError(Exception):
     pass
@@ -184,7 +187,7 @@ def validate_and_prepare_providers(
     specs = {}
     for provider in providers:
         if not provider.provider_id or provider.provider_id == "__disabled__":
-            logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
+            log.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
             continue
 
         validate_provider(provider, api, provider_registry)
@@ -206,11 +209,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
 
     p = provider_registry[api][provider.provider_type]
     if p.deprecation_error:
-        logcat.error("core", p.deprecation_error)
+        log.error(p.deprecation_error)
         raise InvalidProviderError(p.deprecation_error)
     elif p.deprecation_warning:
-        logcat.warning(
-            "core",
+        log.warning(
             f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
         )
@@ -244,10 +246,10 @@ def sort_providers_by_deps(
             )
         )
 
-    logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
+    log.info(f"Resolved {len(sorted_providers)} providers")
     for api_str, provider in sorted_providers:
-        logcat.debug("core", f" {api_str} => {provider.provider_id}")
-    return sorted_providers
+        log.debug(f" {api_str} => {provider.provider_id}")
+    log.debug("")
+    return sorted_providers
 
 
 async def instantiate_providers(
@@ -387,7 +389,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
         obj_params = set(obj_sig.parameters)
         obj_params.discard("self")
         if not (proto_params <= obj_params):
-            logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
+            log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
             missing_methods.append((name, "signature_mismatch"))
         else:
             # Check if the method is actually implemented in the class

@@ -6,7 +6,6 @@
 
 from typing import Any, AsyncGenerator, Dict, List, Optional
 
-from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
@@ -52,6 +51,7 @@ from llama_stack.apis.tools import (
     ToolRuntime,
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable
 
 
@@ -62,15 +62,12 @@ class VectorIORouter(VectorIO):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing VectorIORouter")
         self.routing_table = routing_table
 
     async def initialize(self) -> None:
-        logcat.debug("core", "VectorIORouter.initialize")
         pass
 
     async def shutdown(self) -> None:
-        logcat.debug("core", "VectorIORouter.shutdown")
         pass
 
     async def register_vector_db(
@@ -81,8 +78,7 @@ class VectorIORouter(VectorIO):
         provider_id: Optional[str] = None,
         provider_vector_db_id: Optional[str] = None,
     ) -> None:
-        logcat.debug(
-            "core",
+        logger.debug(
             f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}",
         )
         await self.routing_table.register_vector_db(
@@ -99,8 +95,7 @@ class VectorIORouter(VectorIO):
         chunks: List[Chunk],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        logcat.debug(
-            "core",
+        logger.debug(
             f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
         )
         return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
@@ -111,7 +106,6 @@ class VectorIORouter(VectorIO):
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryChunksResponse:
-        logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}")
         return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
 
 
@@ -122,15 +116,12 @@ class InferenceRouter(Inference):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing InferenceRouter")
         self.routing_table = routing_table
 
     async def initialize(self) -> None:
-        logcat.debug("core", "InferenceRouter.initialize")
         pass
 
     async def shutdown(self) -> None:
-        logcat.debug("core", "InferenceRouter.shutdown")
         pass
 
     async def register_model(
@@ -141,10 +132,6 @@ class InferenceRouter(Inference):
         metadata: Optional[Dict[str, Any]] = None,
         model_type: Optional[ModelType] = None,
     ) -> None:
-        logcat.debug(
-            "core",
-            f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
-        )
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
 
     async def chat_completion(
@@ -160,8 +147,7 @@ class InferenceRouter(Inference):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
-        logcat.debug(
-            "core",
+        logger.debug(
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
         if sampling_params is None:
@@ -226,8 +212,7 @@ class InferenceRouter(Inference):
     ) -> AsyncGenerator:
         if sampling_params is None:
             sampling_params = SamplingParams()
-        logcat.debug(
-            "core",
+        logger.debug(
             f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
         )
         model = await self.routing_table.get_model(model_id)
@@ -257,7 +242,6 @@ class InferenceRouter(Inference):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        logcat.debug("core", f"InferenceRouter.embeddings: {model_id}")
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
@@ -277,15 +261,12 @@ class SafetyRouter(Safety):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing SafetyRouter")
         self.routing_table = routing_table
 
     async def initialize(self) -> None:
-        logcat.debug("core", "SafetyRouter.initialize")
         pass
 
     async def shutdown(self) -> None:
-        logcat.debug("core", "SafetyRouter.shutdown")
         pass
 
     async def register_shield(
@@ -295,7 +276,6 @@ class SafetyRouter(Safety):
         provider_id: Optional[str] = None,
         params: Optional[Dict[str, Any]] = None,
     ) -> Shield:
-        logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}")
         return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
 
     async def run_shield(
@@ -304,7 +284,6 @@ class SafetyRouter(Safety):
         messages: List[Message],
         params: Dict[str, Any] = None,
     ) -> RunShieldResponse:
-        logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}")
         return await self.routing_table.get_provider_impl(shield_id).run_shield(
             shield_id=shield_id,
             messages=messages,
@@ -317,15 +296,12 @@ class DatasetIORouter(DatasetIO):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing DatasetIORouter")
         self.routing_table = routing_table
 
     async def initialize(self) -> None:
-        logcat.debug("core", "DatasetIORouter.initialize")
         pass
 
     async def shutdown(self) -> None:
-        logcat.debug("core", "DatasetIORouter.shutdown")
         pass
 
     async def get_rows_paginated(
@@ -335,8 +311,7 @@ class DatasetIORouter(DatasetIO):
         page_token: Optional[str] = None,
         filter_condition: Optional[str] = None,
     ) -> PaginatedRowsResult:
-        logcat.debug(
-            "core",
+        logger.debug(
             f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
         )
         return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
@@ -347,7 +322,6 @@ class DatasetIORouter(DatasetIO):
         )
 
     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
-        logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
         return await self.routing_table.get_provider_impl(dataset_id).append_rows(
             dataset_id=dataset_id,
             rows=rows,
@@ -359,15 +333,12 @@ class ScoringRouter(Scoring):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing ScoringRouter")
         self.routing_table = routing_table
 
     async def initialize(self) -> None:
-        logcat.debug("core", "ScoringRouter.initialize")
         pass
 
     async def shutdown(self) -> None:
-        logcat.debug("core", "ScoringRouter.shutdown")
         pass
 
     async def score_batch(
@@ -376,7 +347,6 @@ class ScoringRouter(Scoring):
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
-        logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
@@ -397,8 +367,7 @@ class ScoringRouter(Scoring):
         input_rows: List[Dict[str, Any]],
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
-        logcat.debug(
-            "core",
+        logger.debug(
             f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions",
         )
         res = {}
@@ -418,15 +387,12 @@ class EvalRouter(Eval):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing EvalRouter")
         self.routing_table = routing_table
 
     async def initialize(self) -> None:
-        logcat.debug("core", "EvalRouter.initialize")
         pass
 
     async def shutdown(self) -> None:
-        logcat.debug("core", "EvalRouter.shutdown")
         pass
 
     async def run_eval(
@@ -434,7 +400,6 @@ class EvalRouter(Eval):
         benchmark_id: str,
         benchmark_config: BenchmarkConfig,
     ) -> Job:
-        logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
             benchmark_id=benchmark_id,
             benchmark_config=benchmark_config,
@@ -447,7 +412,6 @@ class EvalRouter(Eval):
         scoring_functions: List[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
-        logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
         return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
@@ -460,7 +424,6 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> Optional[JobStatus]:
-        logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
 
     async def job_cancel(
@@ -468,7 +431,6 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> None:
-        logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
         await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
             benchmark_id,
             job_id,
@@ -479,7 +441,6 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> EvaluateResponse:
-        logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).job_result(
             benchmark_id,
             job_id,
@@ -492,7 +453,6 @@ class ToolRuntimeRouter(ToolRuntime):
             self,
             routing_table: RoutingTable,
         ) -> None:
-            logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl")
             self.routing_table = routing_table
 
         async def query(
@@ -501,7 +461,6 @@ class ToolRuntimeRouter(ToolRuntime):
             vector_db_ids: List[str],
             query_config: Optional[RAGQueryConfig] = None,
         ) -> RAGQueryResult:
-            logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
             return await self.routing_table.get_provider_impl("knowledge_search").query(
                 content, vector_db_ids, query_config
             )
@@ -512,10 +471,6 @@ class ToolRuntimeRouter(ToolRuntime):
             vector_db_id: str,
             chunk_size_in_tokens: int = 512,
         ) -> None:
-            logcat.debug(
-                "core",
-                f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}",
-            )
             return await self.routing_table.get_provider_impl("insert_into_memory").insert(
                 documents, vector_db_id, chunk_size_in_tokens
             )
@@ -524,7 +479,6 @@ class ToolRuntimeRouter(ToolRuntime):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing ToolRuntimeRouter")
         self.routing_table = routing_table
 
         # HACK ALERT this should be in sync with "get_all_api_endpoints()"
@@ -533,15 +487,12 @@ class ToolRuntimeRouter(ToolRuntime):
             setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
 
     async def initialize(self) -> None:
-        logcat.debug("core", "ToolRuntimeRouter.initialize")
         pass
 
    async def shutdown(self) -> None:
-        logcat.debug("core", "ToolRuntimeRouter.shutdown")
        pass

    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
-        logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}")
        return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
            tool_name=tool_name,
            kwargs=kwargs,
@@ -550,5 +501,4 @@ class ToolRuntimeRouter(ToolRuntime):
    async def list_runtime_tools(
        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
    ) -> List[ToolDef]:
-        logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)

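The restored routers log through a module-level `logger`, and the import hunk above re-adds `from llama_stack.log import get_logger`, but the line that binds `logger` falls outside the hunks shown here. A plausible reconstruction of that binding, assuming `get_logger` accepts name/category keywords (both the keyword names and the category string are guesses, not taken from this diff):

    from llama_stack.log import get_logger

    # Assumed binding -- not visible in the hunks above; the category is a guess.
    logger = get_logger(name=__name__, category="routers")

    # Illustrative call in the restored style:
    logger.debug("VectorIORouter.query_chunks: my_vector_db")
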
@@ -26,9 +26,9 @@ from fastapi import Path as FastapiPath
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel, ValidationError
+from termcolor import cprint
 from typing_extensions import Annotated
 
-from llama_stack import logcat
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.request_headers import set_request_provider_data
@@ -55,7 +55,7 @@ from .endpoints import get_all_api_endpoints
 REPO_ROOT = Path(__file__).parent.parent.parent.parent
 
 logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
-logcat.init()
+logger = logging.getLogger(__name__)
 
 
 def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@@ -142,23 +142,23 @@ def handle_signal(app, signum, _) -> None:
     not block the current execution.
     """
     signame = signal.Signals(signum).name
-    logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...")
+    logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
 
     async def shutdown():
         try:
             # Gracefully shut down implementations
             for impl in app.__llama_stack_impls__.values():
                 impl_name = impl.__class__.__name__
-                logcat.info("server", f"Shutting down {impl_name}")
+                logger.info("Shutting down %s", impl_name)
                 try:
                     if hasattr(impl, "shutdown"):
                         await asyncio.wait_for(impl.shutdown(), timeout=5)
                     else:
-                        logcat.warning("server", f"No shutdown method for {impl_name}")
+                        logger.warning("No shutdown method for %s", impl_name)
                 except asyncio.TimeoutError:
-                    logcat.exception("server", f"Shutdown timeout for {impl_name}")
+                    logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
                 except Exception as e:
-                    logcat.exception("server", f"Failed to shutdown {impl_name}: {e}")
+                    logger.exception("Failed to shutdown %s: %s", impl_name, {e})
 
             # Gather all running tasks
             loop = asyncio.get_running_loop()
@@ -172,7 +172,7 @@ def handle_signal(app, signum, _) -> None:
             try:
                 await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
             except asyncio.TimeoutError:
-                logcat.exception("server", "Timeout while waiting for tasks to finish")
+                logger.exception("Timeout while waiting for tasks to finish")
             except asyncio.CancelledError:
                 pass
             finally:
@@ -184,9 +184,9 @@ def handle_signal(app, signum, _) -> None:
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    logcat.info("server", "Starting up")
+    logger.info("Starting up")
     yield
-    logcat.info("server", "Shutting down")
+    logger.info("Shutting down")
     for impl in app.__llama_stack_impls__.values():
         await impl.shutdown()
@@ -209,11 +209,11 @@ async def sse_generator(event_gen):
             yield create_sse_event(item)
             await asyncio.sleep(0.01)
     except asyncio.CancelledError:
-        logcat.info("server", "Generator cancelled")
+        print("Generator cancelled")
         await event_gen.aclose()
     except Exception as e:
-        logcat.exception("server", f"Error in sse_generator: {e}")
-        logcat.exception("server", f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
+        logger.exception(f"Error in sse_generator: {e}")
+        logger.exception(f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
         yield create_sse_event(
             {
                 "error": {
@@ -235,7 +235,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
             value = func(**kwargs)
             return await maybe_await(value)
         except Exception as e:
-            logcat.exception("server", f"Error in {func.__name__}")
+            traceback.print_exception(e)
             raise translate_exception(e) from e
 
     sig = inspect.signature(func)
@@ -314,8 +314,6 @@ class ClientVersionMiddleware:
 
 
 def main():
-    logcat.init()
-
     """Start the LlamaStack server."""
     parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
     parser.add_argument(
@@ -355,10 +353,10 @@ def main():
     for env_pair in args.env:
         try:
             key, value = validate_env_pair(env_pair)
-            logcat.info("server", f"Setting CLI environment variable {key} => {value}")
+            logger.info(f"Setting CLI environment variable {key} => {value}")
             os.environ[key] = value
         except ValueError as e:
-            logcat.error("server", f"Error: {str(e)}")
+            logger.error(f"Error: {str(e)}")
             sys.exit(1)
 
     if args.yaml_config:
@@ -366,12 +364,12 @@ def main():
         config_file = Path(args.yaml_config)
         if not config_file.exists():
             raise ValueError(f"Config file {config_file} does not exist")
-        logcat.info("server", f"Using config file: {config_file}")
+        logger.info(f"Using config file: {config_file}")
     elif args.template:
         config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
         if not config_file.exists():
             raise ValueError(f"Template {args.template} does not exist")
-        logcat.info("server", f"Using template {args.template} config file: {config_file}")
+        logger.info(f"Using template {args.template} config file: {config_file}")
     else:
         raise ValueError("Either --yaml-config or --template must be provided")
 
@@ -379,10 +377,9 @@ def main():
         config = replace_env_vars(yaml.safe_load(fp))
         config = StackRunConfig(**config)
 
-    logcat.info("server", "Run configuration:")
+    logger.info("Run configuration:")
     safe_config = redact_sensitive_fields(config.model_dump())
-    for log_line in yaml.dump(safe_config, indent=2).split("\n"):
-        logcat.info("server", log_line)
+    logger.info(yaml.dump(safe_config, indent=2))
 
     app = FastAPI(lifespan=lifespan)
     app.add_middleware(TracingMiddleware)
@@ -392,7 +389,7 @@ def main():
     try:
        impls = asyncio.run(construct_stack(config))
    except InvalidProviderError as e:
-        logcat.error("server", f"Error: {str(e)}")
+        logger.error(f"Error: {str(e)}")
        sys.exit(1)

    if Api.telemetry in impls:
@@ -437,8 +434,9 @@ def main():
        )
    )

-    logcat.debug("server", f"serving APIs: {apis_to_serve}")
+    logger.debug(f"serving APIs: {apis_to_serve}")

+    print("")
    app.exception_handler(RequestValidationError)(global_exception_handler)
    app.exception_handler(Exception)(global_exception_handler)
    signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
@@ -464,10 +462,10 @@ def main():
        "ssl_keyfile": keyfile,
        "ssl_certfile": certfile,
    }
-    logcat.info("server", f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
+    logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")

    listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
-    logcat.info("server", f"Listening on {listen_host}:{port}")
+    logger.info(f"Listening on {listen_host}:{port}")

    uvicorn_config = {
        "app": app,

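The largest behavioral piece restored in this file is the `shutdown()` coroutine inside `handle_signal`, which bounds each implementation's shutdown with a per-impl timeout so one hung provider cannot stall process exit, then drains remaining tasks. The same pattern in isolation, as a runnable sketch with a stand-in implementation rather than the server's real registry:

    import asyncio
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)


    class SlowImpl:
        # Stand-in for a provider implementation held by the server.
        async def shutdown(self) -> None:
            await asyncio.sleep(0.1)


    async def shutdown_all(impls) -> None:
        for impl in impls:
            name = impl.__class__.__name__
            logger.info("Shutting down %s", name)
            try:
                if hasattr(impl, "shutdown"):
                    # Bound each shutdown; a hung impl costs at most 5 seconds.
                    await asyncio.wait_for(impl.shutdown(), timeout=5)
                else:
                    logger.warning("No shutdown method for %s", name)
            except asyncio.TimeoutError:
                logger.exception("Shutdown timeout for %s", name)
            except Exception:
                logger.exception("Failed to shutdown %s", name)


    asyncio.run(shutdown_all([SlowImpl()]))
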
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 import importlib.resources
+import logging
 import os
 import re
 import tempfile
@@ -13,7 +14,6 @@ from typing import Any, Dict, Optional
 import yaml
 from termcolor import colored
 
-from llama_stack import logcat
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.batch_inference import BatchInference
 from llama_stack.apis.benchmarks import Benchmarks
@@ -41,6 +41,8 @@ from llama_stack.distribution.store.registry import create_dist_registry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.providers.datatypes import Api
 
+log = logging.getLogger(__name__)
+
 
 class LlamaStack(
     VectorDBs,
@@ -101,11 +103,12 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
         objects_to_process = response.data if hasattr(response, "data") else response
 
         for obj in objects_to_process:
-            logcat.debug(
-                "core",
+            log.info(
                 f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
             )
 
+    log.info("")
+
 
 class EnvVarError(Exception):
     def __init__(self, var_name: str, path: str = ""):

@@ -98,8 +98,9 @@ case "$env_type" in
     *)
 esac
 
+set -x
+
 if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
-    set -x
     $PYTHON_BINARY -m llama_stack.distribution.server.server \
         --yaml-config "$yaml_config" \
         --port "$port" \
@@ -141,8 +142,6 @@ elif [[ "$env_type" == "container" ]]; then
     version_tag=$(curl -s $URL | jq -r '.info.version')
 fi
 
-set -x
-
 $CONTAINER_BINARY run $CONTAINER_OPTS -it \
     -p $port:$port \
     $env_vars \

@@ -1,204 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-"""
-Category-based logging utility for llama-stack.
-
-This module provides a wrapper over the standard Python logging module that supports
-categorized logging with environment variable control.
-
-Usage:
-    from llama_stack import logcat
-    logcat.info("server", "Starting up...")
-    logcat.debug("inference", "Processing request...")
-
-Environment variable:
-    LLAMA_STACK_LOGGING: Semicolon-separated list of category=level pairs
-    Example: "server=debug;inference=warning"
-"""
-
-import datetime
-import logging
-import os
-from typing import Dict
-
-# ANSI color codes for terminal output
-COLORS = {
-    "RESET": "\033[0m",
-    "DEBUG": "\033[36m",  # Cyan
-    "INFO": "\033[32m",  # Green
-    "WARNING": "\033[33m",  # Yellow
-    "ERROR": "\033[31m",  # Red
-    "CRITICAL": "\033[35m",  # Magenta
-    "DIM": "\033[2m",  # Dimmed text
-    "YELLOW_DIM": "\033[2;33m",  # Dimmed yellow
-}
-
-# Static list of valid categories representing various parts of the Llama Stack
-# server codebase
-CATEGORIES = [
-    "core",
-    "server",
-    "router",
-    "inference",
-    "agents",
-    "safety",
-    "eval",
-    "tools",
-    "client",
-]
-
-_logger = logging.getLogger("llama_stack")
-_logger.propagate = False
-
-_default_level = logging.INFO
-
-# Category-level mapping (can be modified by environment variables)
-_category_levels: Dict[str, int] = {}
-
-
-class TerminalStreamHandler(logging.StreamHandler):
-    def __init__(self, stream=None):
-        super().__init__(stream)
-        self.is_tty = hasattr(self.stream, "isatty") and self.stream.isatty()
-
-    def format(self, record):
-        record.is_tty = self.is_tty
-        return super().format(record)
-
-
-class ColoredFormatter(logging.Formatter):
-    """Custom formatter with colors and fixed-width level names"""
-
-    def format(self, record):
-        levelname = record.levelname
-        # Use only time with milliseconds, not date
-        timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3]  # HH:MM:SS.mmm format
-
-        file_info = f"{record.filename}:{record.lineno}"
-
-        # Get category from extra if available
-        category = getattr(record, "category", None)
-        msg = record.getMessage()
-
-        if getattr(record, "is_tty", False):
-            color = COLORS.get(levelname, COLORS["RESET"])
-            if category:
-                category_formatted = f"{COLORS['YELLOW_DIM']}{category}{COLORS['RESET']} "
-                formatted_msg = (
-                    f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} "
-                    f"{file_info:<20} {category_formatted}{msg}"
-                )
-            else:
-                formatted_msg = (
-                    f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']}] "
-                    f"{file_info:<20} {msg}"
-                )
-        else:
-            if category:
-                formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} [{category}] {msg}"
-            else:
-                formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} {msg}"
-
-        return formatted_msg
-
-
-def init(default_level: int = logging.INFO) -> None:
-    global _default_level, _category_levels, _logger
-
-    _default_level = default_level
-
-    _logger.setLevel(logging.DEBUG)
-    _logger.handlers = []  # Clear existing handlers
-
-    # Add our custom handler with the colored formatter
-    handler = TerminalStreamHandler()
-    formatter = ColoredFormatter()
-    handler.setFormatter(formatter)
-    _logger.addHandler(handler)
-
-    for category in CATEGORIES:
-        _category_levels[category] = default_level
-
-    env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
-    if env_config:
-        for pair in env_config.split(";"):
-            if not pair.strip():
-                continue
-
-            try:
-                category, level = pair.split("=", 1)
-                category = category.strip().lower()
-                level = level.strip().lower()
-
-                level_value = {
-                    "debug": logging.DEBUG,
-                    "info": logging.INFO,
-                    "warning": logging.WARNING,
-                    "warn": logging.WARNING,
-                    "error": logging.ERROR,
-                    "critical": logging.CRITICAL,
-                }.get(level)
-
-                if level_value is None:
-                    _logger.warning(f"Unknown log level '{level}' for category '{category}'")
-                    continue
-
-                if category == "all":
-                    for cat in CATEGORIES:
-                        _category_levels[cat] = level_value
-                else:
-                    if category in CATEGORIES:
-                        _category_levels[category] = level_value
-                    else:
-                        _logger.warning(f"Unknown logging category: {category}")
-
-            except ValueError:
-                _logger.warning(f"Invalid logging configuration: {pair}")
-
-
-def _should_log(level: int, category: str) -> bool:
-    category = category.lower()
-    if category not in _category_levels:
-        return False
-    category_level = _category_levels[category]
-    return level >= category_level
-
-
-def _log(level: int, level_name: str, category: str, msg: str, *args, **kwargs) -> None:
-    if _should_log(level, category):
-        kwargs.setdefault("extra", {})["category"] = category.lower()
-        getattr(_logger, level_name)(msg, *args, stacklevel=3, **kwargs)
-
-
-def debug(category: str, msg: str, *args, **kwargs) -> None:
-    _log(logging.DEBUG, "debug", category, msg, *args, **kwargs)
-
-
-def info(category: str, msg: str, *args, **kwargs) -> None:
-    _log(logging.INFO, "info", category, msg, *args, **kwargs)
-
-
-def warning(category: str, msg: str, *args, **kwargs) -> None:
-    _log(logging.WARNING, "warning", category, msg, *args, **kwargs)
-
-
-def warn(category: str, msg: str, *args, **kwargs) -> None:
-    warning(category, msg, *args, **kwargs)
-
-
-def error(category: str, msg: str, *args, **kwargs) -> None:
-    _log(logging.ERROR, "error", category, msg, *args, **kwargs)
-
-
-def critical(category: str, msg: str, *args, **kwargs) -> None:
-    _log(logging.CRITICAL, "critical", category, msg, *args, **kwargs)
-
-
-def exception(category: str, msg: str, *args, **kwargs) -> None:
-    if _should_log(logging.ERROR, category):
-        kwargs.setdefault("extra", {})["category"] = category.lower()
-        _logger.exception(msg, *args, stacklevel=2, **kwargs)

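Since the module is deleted wholesale here, the environment-variable behavior it implemented is worth keeping visible. This sketch exercises the API exactly as the docstring and the removed unit tests describe; it only runs against a checkout that still contains llama_stack/logcat.py:

    import logging
    import os

    # Per-category levels, parsed by logcat.init() as described in its docstring.
    os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning"

    from llama_stack import logcat  # removed by this commit

    logcat.init(default_level=logging.INFO)
    logcat.debug("server", "emitted: server category is at debug")
    logcat.debug("inference", "suppressed: inference category is at warning")
    logcat.info("nonexistent", "suppressed: unknown categories never log")
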
@@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union
 
 from fireworks.client import Fireworks
 
-from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
@@ -231,14 +230,12 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             if input_dict["prompt"].startswith("<|begin_of_text|>"):
                 input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
 
-        params = {
+        return {
             "model": request.model,
             **input_dict,
             "stream": request.stream,
             **self._build_options(request.sampling_params, request.response_format, request.logprobs),
         }
-        logcat.debug("inference", f"params to fireworks: {params}")
-        return params
 
     async def embeddings(
         self,

@@ -10,7 +10,6 @@ from typing import AsyncGenerator, List, Optional, Union
 import httpx
 from ollama import AsyncClient
 
-from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
@@ -208,14 +207,12 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         else:
             raise ValueError(f"Unknown response format type: {fmt.type}")
 
-        params = {
+        return {
             "model": request.model,
             **input_dict,
             "options": sampling_options,
             "stream": request.stream,
         }
-        logcat.debug("inference", f"params to ollama: {params}")
-        return params
 
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)

@@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union
 
 from together import Together
 
-from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
@@ -218,14 +217,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             assert not media_present, "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
 
-        params = {
+        return {
             "model": request.model,
             **input_dict,
             "stream": request.stream,
             **self._build_options(request.sampling_params, request.logprobs, request.response_format),
         }
-        logcat.debug("inference", f"params to together: {params}")
-        return params
 
     async def embeddings(
         self,

@@ -8,7 +8,6 @@ from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
 
 import litellm
 
-from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
@@ -109,8 +108,6 @@ class LiteLLMOpenAIMixin(
         )
 
         params = await self._get_params(request)
-        logcat.debug("inference", f"params to litellm (openai compat): {params}")
-
         # unfortunately, we need to use synchronous litellm.completion here because litellm
         # caches various httpx.client objects in a non-eventloop aware manner
         response = litellm.completion(**params)

@@ -151,7 +151,6 @@ exclude = [
     "llama_stack/distribution",
     "llama_stack/apis",
     "llama_stack/cli",
-    "llama_stack/logcat.py",
     "llama_stack/models",
     "llama_stack/strong_typing",
     "llama_stack/templates",

@@ -1,88 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import io
-import logging
-import os
-import unittest
-
-from llama_stack import logcat
-
-
-class TestLogcat(unittest.TestCase):
-    def setUp(self):
-        self.original_env = os.environ.get("LLAMA_STACK_LOGGING")
-
-        self.log_output = io.StringIO()
-        self._init_logcat()
-
-    def tearDown(self):
-        if self.original_env is not None:
-            os.environ["LLAMA_STACK_LOGGING"] = self.original_env
-        else:
-            os.environ.pop("LLAMA_STACK_LOGGING", None)
-
-    def _init_logcat(self):
-        logcat.init(default_level=logging.DEBUG)
-        self.handler = logging.StreamHandler(self.log_output)
-        self.handler.setFormatter(logging.Formatter("[%(category)s] %(message)s"))
-        logcat._logger.handlers.clear()
-        logcat._logger.addHandler(self.handler)
-
-    def test_basic_logging(self):
-        logcat.info("server", "Info message")
-        logcat.warning("server", "Warning message")
-        logcat.error("server", "Error message")
-
-        output = self.log_output.getvalue()
-        self.assertIn("[server] Info message", output)
-        self.assertIn("[server] Warning message", output)
-        self.assertIn("[server] Error message", output)
-
-    def test_different_categories(self):
-        # Log messages with different categories
-        logcat.info("server", "Server message")
-        logcat.info("inference", "Inference message")
-        logcat.info("router", "Router message")
-
-        output = self.log_output.getvalue()
-        self.assertIn("[server] Server message", output)
-        self.assertIn("[inference] Inference message", output)
-        self.assertIn("[router] Router message", output)
-
-    def test_env_var_control(self):
-        os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning"
-        self._init_logcat()
-
-        # These should be visible based on the environment settings
-        logcat.debug("server", "Server debug message")
-        logcat.info("server", "Server info message")
-        logcat.warning("inference", "Inference warning message")
-        logcat.error("inference", "Inference error message")
-
-        # These should be filtered out based on the environment settings
-        logcat.debug("inference", "Inference debug message")
-        logcat.info("inference", "Inference info message")
-
-        output = self.log_output.getvalue()
-        self.assertIn("[server] Server debug message", output)
-        self.assertIn("[server] Server info message", output)
-        self.assertIn("[inference] Inference warning message", output)
-        self.assertIn("[inference] Inference error message", output)
-
-        self.assertNotIn("[inference] Inference debug message", output)
-        self.assertNotIn("[inference] Inference info message", output)
-
-    def test_invalid_category(self):
-        logcat.info("nonexistent", "This message should not be logged")
-
-        # Check that the message was not logged
-        output = self.log_output.getvalue()
-        self.assertNotIn("[nonexistent] This message should not be logged", output)
-
-
-if __name__ == "__main__":
-    unittest.main()