Mirror of https://github.com/meta-llama/llama-stack.git, synced 2026-01-01 17:34:32 +00:00

Commit a714bbac9d: Merge-related changes.

95 changed files with 11044 additions and 4639 deletions
@@ -16,7 +16,7 @@ from termcolor import cprint
 from llama_stack.distribution.datatypes import BuildConfig, Provider
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.utils.exec import run_command, run_with_pty
-from llama_stack.distribution.utils.image_types import ImageType
+from llama_stack.distribution.utils.image_types import LlamaStackImageType
 from llama_stack.providers.datatypes import Api

 log = logging.getLogger(__name__)
@@ -95,7 +95,7 @@ def build_image(
     normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
     normal_deps += SERVER_DEPENDENCIES

-    if build_config.image_type == ImageType.container.value:
+    if build_config.image_type == LlamaStackImageType.CONTAINER.value:
         script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
         args = [
             script,
@@ -104,7 +104,7 @@ def build_image(
             container_base,
             " ".join(normal_deps),
         ]
-    elif build_config.image_type == ImageType.conda.value:
+    elif build_config.image_type == LlamaStackImageType.CONDA.value:
         script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
         args = [
             script,
@@ -112,7 +112,7 @@ def build_image(
             str(build_file_path),
             " ".join(normal_deps),
         ]
-    elif build_config.image_type == ImageType.venv.value:
+    elif build_config.image_type == LlamaStackImageType.VENV.value:
         script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
         args = [
             script,
@@ -32,7 +32,10 @@ from termcolor import cprint
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import Api
-from llama_stack.distribution.request_headers import set_request_provider_data
+from llama_stack.distribution.request_headers import (
+    preserve_headers_context_async_generator,
+    request_provider_data_context,
+)
 from llama_stack.distribution.resolver import ProviderRegistry
 from llama_stack.distribution.server.endpoints import get_all_api_endpoints
 from llama_stack.distribution.stack import (
@@ -160,6 +163,9 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             except StopAsyncIteration:
                 pass
             finally:
+                pending = asyncio.all_tasks(loop)
+                if pending:
+                    loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
                 loop.close()

         return sync_generator()
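The three added lines matter when a caller abandons a synchronous stream early: any provider call still in flight would otherwise die with "Task was destroyed but it is pending!" when the loop closes. A minimal standalone sketch of the same cleanup pattern (toy coroutine, not from this commit):

    import asyncio

    async def in_flight_work():
        await asyncio.sleep(0.1)  # stands in for an unfinished provider call

    loop = asyncio.new_event_loop()
    loop.create_task(in_flight_work())  # scheduled but never awaited by the consumer

    # Mirror the finally-block above: drain pending tasks before closing the loop.
    pending = asyncio.all_tasks(loop)
    if pending:
        loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
    loop.close()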
@@ -262,21 +268,25 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if not self.endpoint_impls:
             raise ValueError("Client not initialized")

+        # Create headers with provider data if available
+        headers = {}
         if self.provider_data:
-            set_request_provider_data({"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)})
+            headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)

-        if stream:
-            response = await self._call_streaming(
-                cast_to=cast_to,
-                options=options,
-                stream_cls=stream_cls,
-            )
-        else:
-            response = await self._call_non_streaming(
-                cast_to=cast_to,
-                options=options,
-            )
-        return response
+        # Use context manager for provider data
+        with request_provider_data_context(headers):
+            if stream:
+                response = await self._call_streaming(
+                    cast_to=cast_to,
+                    options=options,
+                    stream_cls=stream_cls,
+                )
+            else:
+                response = await self._call_non_streaming(
+                    cast_to=cast_to,
+                    options=options,
+                )
+            return response

     def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
         """Find the matching endpoint implementation for a given method and path.
@@ -374,9 +384,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             finally:
                 await end_trace()

+        # Wrap the generator to preserve context across iterations
+        wrapped_gen = preserve_headers_context_async_generator(gen())
         mock_response = httpx.Response(
             status_code=httpx.codes.OK,
-            content=gen(),
+            content=wrapped_gen,
             headers={
                 "Content-Type": "application/json",
             },
@@ -4,16 +4,62 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import contextvars
 import json
 import logging
-import threading
-from typing import Any, Dict
+from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar

 from .utils.dynamic import instantiate_class_type

 log = logging.getLogger(__name__)

-_THREAD_LOCAL = threading.local()
+# Context variable for request provider data
+_provider_data_var = contextvars.ContextVar("provider_data", default=None)
+
+
+class RequestProviderDataContext(ContextManager):
+    """Context manager for request provider data"""
+
+    def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
+        self.provider_data = provider_data
+        self.token = None
+
+    def __enter__(self):
+        # Save the current value and set the new one
+        self.token = _provider_data_var.set(self.provider_data)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # Restore the previous value
+        if self.token is not None:
+            _provider_data_var.reset(self.token)
+
+
+T = TypeVar("T")
+
+
+def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
+    """
+    Wraps an async generator to preserve request headers context variables across iterations.
+
+    This ensures that context variables set during generator creation are
+    available during each iteration of the generator, even if the original
+    context manager has exited.
+    """
+    # Capture the current context value right now
+    context_value = _provider_data_var.get()
+
+    async def wrapper():
+        while True:
+            # Set context before each anext() call
+            _ = _provider_data_var.set(context_value)
+            try:
+                item = await gen.__anext__()
+                yield item
+            except StopAsyncIteration:
+                break
+
+    return wrapper()


 class NeedsRequestProviderData:
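Why the wrapper is needed: an async generator body runs in whatever context is current at each iteration, not the context that existed when the generator was created, so a value set inside a `with` block is already gone by the time the HTTP layer pulls the first chunk. A toy reproduction of the loss (names invented for illustration, not from this commit):

    import asyncio
    import contextvars

    _var = contextvars.ContextVar("demo", default=None)

    async def stream():
        for _ in range(2):
            yield _var.get()  # reads the *iterating* context, not the creating one

    async def main():
        _var.set("request-scoped")
        gen = stream()
        _var.set(None)  # the "context manager" has exited before iteration starts
        print([item async for item in gen])  # [None, None] -- the value was lost

    asyncio.run(main())

preserve_headers_context_async_generator avoids this by capturing the value at wrap time and re-setting it before every __anext__() call.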
@@ -26,7 +72,7 @@ class NeedsRequestProviderData:
         if not validator_class:
             raise ValueError(f"Provider {provider_type} does not have a validator")

-        val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
+        val = _provider_data_var.get()
         if not val:
             return None

@@ -36,25 +82,32 @@ class NeedsRequestProviderData:
             return provider_data
         except Exception as e:
             log.error(f"Error parsing provider data: {e}")
             return None


-def set_request_provider_data(headers: Dict[str, str]):
+def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]:
+    """Parse provider data from request headers"""
     keys = [
         "X-LlamaStack-Provider-Data",
         "x-llamastack-provider-data",
     ]
     val = None
     for key in keys:
         val = headers.get(key, None)
         if val:
             break

     if not val:
-        return
+        return None

     try:
-        val = json.loads(val)
+        return json.loads(val)
     except json.JSONDecodeError:
-        log.error("Provider data not encoded as a JSON object!", val)
-        return
+        log.error("Provider data not encoded as a JSON object!")
+        return None

-    _THREAD_LOCAL.provider_data_header_value = val
+
+def request_provider_data_context(headers: Dict[str, str]) -> ContextManager:
+    """Context manager that sets request provider data from headers for the duration of the context"""
+    provider_data = parse_request_provider_data(headers)
+    return RequestProviderDataContext(provider_data)
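End to end, the old implicit thread-local mutation becomes an explicit, scoped flow: parse the header once, hold it in a context variable for the duration of the request, restore the previous value on exit. A hedged sketch of how the pieces compose (header payload invented):

    import json

    from llama_stack.distribution.request_headers import request_provider_data_context

    headers = {"X-LlamaStack-Provider-Data": json.dumps({"api_key": "sk-test"})}

    with request_provider_data_context(headers):
        # Provider code that mixes in NeedsRequestProviderData now sees
        # {"api_key": "sk-test"} via _provider_data_var.get().
        pass
    # On __exit__ the previous value is restored, so nothing leaks across requests.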
@@ -7,7 +7,6 @@ import importlib
 import inspect
 from typing import Any, Dict, List, Set, Tuple

-from llama_stack import logcat
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.benchmarks import Benchmarks
 from llama_stack.apis.datasetio import DatasetIO
@@ -37,6 +36,7 @@ from llama_stack.distribution.datatypes import (
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import (
     Api,
     BenchmarksProtocolPrivate,
@@ -53,6 +53,8 @@ from llama_stack.providers.datatypes import (
     VectorDBsProtocolPrivate,
 )

+logger = get_logger(name=__name__, category="core")
+

 class InvalidProviderError(Exception):
     pass
@@ -169,9 +171,7 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
                 module="llama_stack.distribution.routers",
                 routing_table_api=info.routing_table_api,
                 api_dependencies=[info.routing_table_api],
-                # Add telemetry as an optional dependency to all auto-routed providers
-                optional_api_dependencies=[Api.telemetry],
-                deps__=([info.routing_table_api.value, Api.telemetry.value]),
+                deps__=[info.routing_table_api.value],
             ),
         )
     }
@@ -192,7 +192,7 @@ def validate_and_prepare_providers(
     specs = {}
     for provider in providers:
         if not provider.provider_id or provider.provider_id == "__disabled__":
-            logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
+            logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
             continue

         validate_provider(provider, api, provider_registry)
@@ -214,11 +214,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):

     p = provider_registry[api][provider.provider_type]
     if p.deprecation_error:
-        logcat.error("core", p.deprecation_error)
+        logger.error(p.deprecation_error)
         raise InvalidProviderError(p.deprecation_error)
     elif p.deprecation_warning:
-        logcat.warning(
-            "core",
+        logger.warning(
             f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
         )

@@ -252,9 +251,10 @@ def sort_providers_by_deps(
             )
         )

-    logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
+    logger.debug(f"Resolved {len(sorted_providers)} providers")
     for api_str, provider in sorted_providers:
-        logcat.debug("core", f" {api_str} => {provider.provider_id}")
+        logger.debug(f" {api_str} => {provider.provider_id}")
+    logger.debug("")
     return sorted_providers

@@ -395,7 +395,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
         obj_params = set(obj_sig.parameters)
         obj_params.discard("self")
         if not (proto_params <= obj_params):
-            logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
+            logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
             missing_methods.append((name, "signature_mismatch"))
         else:
             # Check if the method is actually implemented in the class
@@ -47,7 +47,7 @@ async def get_routing_table_impl(
     return impl


-async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
+async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
     from .routers import (
         DatasetIORouter,
         EvalRouter,
@@ -69,17 +69,9 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
         "tool_runtime": ToolRuntimeRouter,
         "preprocessing": PreprocessingRouter,
     }
-    api_to_deps = {
-        "inference": {"telemetry": Api.telemetry},
-    }
     if api.value not in api_to_routers:
         raise ValueError(f"API {api.value} not found in router map")

-    api_to_dep_impl = {}
-    for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
-        if dep_api in deps:
-            api_to_dep_impl[dep_name] = deps[dep_api]
-
-    impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
+    impl = api_to_routers[api.value](routing_table)
     await impl.initialize()
     return impl
@@ -4,10 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import time
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, Dict, List, Optional

-from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
@@ -22,10 +20,6 @@ from llama_stack.apis.eval import (
     JobStatus,
 )
 from llama_stack.apis.inference import (
-    ChatCompletionResponse,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionMessage,
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
@@ -33,14 +27,13 @@ from llama_stack.apis.inference import (
     Message,
     ResponseFormat,
     SamplingParams,
-    StopReason,
     TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.models import Model, ModelType
+from llama_stack.apis.models import ModelType
 from llama_stack.apis.preprocessing import (
     Preprocessing,
     PreprocessingDataElement,
@@ -55,7 +48,6 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
-from llama_stack.apis.telemetry import MetricEvent, Telemetry
 from llama_stack.apis.tools import (
     RAGDocument,
     RAGQueryConfig,
|
@ -66,10 +58,10 @@ from llama_stack.apis.tools import (
|
|||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.distribution.utils.chain import execute_preprocessor_chain
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class VectorIORouter(VectorIO):
|
||||
|
|
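The logcat-to-logger migration that runs through the rest of this file swaps a category argument on every call for a logger bound to its category once at import time. The pattern, exactly as it appears in this diff:

    from llama_stack.log import get_logger

    logger = get_logger(name=__name__, category="core")

    # Before: logcat.debug("core", "Initializing VectorIORouter")
    logger.debug("Initializing VectorIORouter")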
@@ -79,15 +71,15 @@ class VectorIORouter(VectorIO):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing VectorIORouter")
+        logger.debug("Initializing VectorIORouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
-        logcat.debug("core", "VectorIORouter.initialize")
+        logger.debug("VectorIORouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logcat.debug("core", "VectorIORouter.shutdown")
+        logger.debug("VectorIORouter.shutdown")
         pass

     async def register_vector_db(
@@ -98,10 +90,7 @@ class VectorIORouter(VectorIO):
         provider_id: Optional[str] = None,
         provider_vector_db_id: Optional[str] = None,
     ) -> None:
-        logcat.debug(
-            "core",
-            f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}",
-        )
+        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
         await self.routing_table.register_vector_db(
             vector_db_id,
             embedding_model,
@@ -116,8 +105,7 @@ class VectorIORouter(VectorIO):
         chunks: List[Chunk],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        logcat.debug(
-            "core",
+        logger.debug(
             f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
         )
         return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
@@ -128,7 +116,7 @@ class VectorIORouter(VectorIO):
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryChunksResponse:
-        logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}")
+        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
         return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)

@@ -138,21 +126,16 @@ class InferenceRouter(Inference):
     def __init__(
         self,
         routing_table: RoutingTable,
-        telemetry: Optional[Telemetry] = None,
     ) -> None:
-        logcat.debug("core", "Initializing InferenceRouter")
+        logger.debug("Initializing InferenceRouter")
         self.routing_table = routing_table
-        self.telemetry = telemetry
-        if self.telemetry:
-            self.tokenizer = Tokenizer.get_instance()
-            self.formatter = ChatFormat(self.tokenizer)

     async def initialize(self) -> None:
-        logcat.debug("core", "InferenceRouter.initialize")
+        logger.debug("InferenceRouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logcat.debug("core", "InferenceRouter.shutdown")
+        logger.debug("InferenceRouter.shutdown")
         pass

     async def register_model(
@@ -163,63 +146,11 @@ class InferenceRouter(Inference):
         metadata: Optional[Dict[str, Any]] = None,
         model_type: Optional[ModelType] = None,
     ) -> None:
-        logcat.debug(
-            "core",
+        logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

-    def _construct_metrics(
-        self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
-    ) -> List[MetricEvent]:
-        span = get_current_span()
-        metrics = [
-            ("prompt_tokens", prompt_tokens),
-            ("completion_tokens", completion_tokens),
-            ("total_tokens", total_tokens),
-        ]
-        metric_events = []
-        for metric_name, value in metrics:
-            metric_events.append(
-                MetricEvent(
-                    trace_id=span.trace_id,
-                    span_id=span.span_id,
-                    metric=metric_name,
-                    value=value,
-                    timestamp=time.time(),
-                    unit="tokens",
-                    attributes={
-                        "model_id": model.model_id,
-                        "provider_id": model.provider_id,
-                    },
-                )
-            )
-        return metric_events
-
-    async def _compute_and_log_token_usage(
-        self,
-        prompt_tokens: int,
-        completion_tokens: int,
-        total_tokens: int,
-        model: Model,
-    ) -> List[MetricEvent]:
-        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
-        if self.telemetry:
-            for metric in metrics:
-                await self.telemetry.log_event(metric)
-        return metrics
-
-    async def _count_tokens(
-        self,
-        messages: List[Message] | InterleavedContent,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
-    ) -> Optional[int]:
-        if isinstance(messages, list):
-            encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
-        else:
-            encoded = self.formatter.encode_content(messages)
-        return len(encoded.tokens) if encoded and encoded.tokens else 0
-
     async def chat_completion(
         self,
         model_id: str,
@@ -232,9 +163,8 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
-        logcat.debug(
-            "core",
+    ) -> AsyncGenerator:
+        logger.debug(
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
         if sampling_params is None:
@@ -283,47 +213,10 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
-        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
-
         if stream:
-
-            async def stream_generator():
-                completion_text = ""
-                async for chunk in await provider.chat_completion(**params):
-                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
-                        if chunk.event.delta.type == "text":
-                            completion_text += chunk.event.delta.text
-                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
-                        completion_tokens = await self._count_tokens(
-                            [CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
-                            tool_config.tool_prompt_format,
-                        )
-                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-                        metrics = await self._compute_and_log_token_usage(
-                            prompt_tokens or 0,
-                            completion_tokens or 0,
-                            total_tokens,
-                            model,
-                        )
-                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
-                    yield chunk
-
-            return stream_generator()
+            return (chunk async for chunk in await provider.chat_completion(**params))
         else:
-            response = await provider.chat_completion(**params)
-            completion_tokens = await self._count_tokens(
-                [response.completion_message],
-                tool_config.tool_prompt_format,
-            )
-            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-            metrics = await self._compute_and_log_token_usage(
-                prompt_tokens or 0,
-                completion_tokens or 0,
-                total_tokens,
-                model,
-            )
-            response.metrics = metrics if response.metrics is None else response.metrics + metrics
-            return response
+            return await provider.chat_completion(**params)

     async def completion(
         self,
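With the token-usage bookkeeping reverted, the streaming branch is a plain passthrough. The `(chunk async for chunk in await provider.chat_completion(**params))` form awaits the provider coroutine once and then just re-yields its chunks; a toy model of the same delegation (stand-in names, not llama-stack APIs):

    import asyncio

    async def fake_chat_completion():  # stands in for provider.chat_completion(stream=True)
        async def chunks():
            for i in range(3):
                yield f"chunk-{i}"
        return chunks()

    async def main():
        passthrough = (chunk async for chunk in await fake_chat_completion())
        print([c async for c in passthrough])  # ['chunk-0', 'chunk-1', 'chunk-2']

    asyncio.run(main())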
@@ -336,8 +229,7 @@ class InferenceRouter(Inference):
     ) -> AsyncGenerator:
         if sampling_params is None:
             sampling_params = SamplingParams()
-        logcat.debug(
-            "core",
+        logger.debug(
             f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
         )
         model = await self.routing_table.get_model(model_id)
@@ -354,41 +246,10 @@ class InferenceRouter(Inference):
             stream=stream,
             logprobs=logprobs,
         )
-
-        prompt_tokens = await self._count_tokens(content)
-
         if stream:
-
-            async def stream_generator():
-                completion_text = ""
-                async for chunk in await provider.completion(**params):
-                    if hasattr(chunk, "delta"):
-                        completion_text += chunk.delta
-                    if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
-                        completion_tokens = await self._count_tokens(completion_text)
-                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-                        metrics = await self._compute_and_log_token_usage(
-                            prompt_tokens or 0,
-                            completion_tokens or 0,
-                            total_tokens,
-                            model,
-                        )
-                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
-                    yield chunk
-
-            return stream_generator()
+            return (chunk async for chunk in await provider.completion(**params))
         else:
-            response = await provider.completion(**params)
-            completion_tokens = await self._count_tokens(response.content)
-            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-            metrics = await self._compute_and_log_token_usage(
-                prompt_tokens or 0,
-                completion_tokens or 0,
-                total_tokens,
-                model,
-            )
-            response.metrics = metrics if response.metrics is None else response.metrics + metrics
-            return response
+            return await provider.completion(**params)

     async def embeddings(
         self,
@@ -398,7 +259,7 @@ class InferenceRouter(Inference):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        logcat.debug("core", f"InferenceRouter.embeddings: {model_id}")
+        logger.debug(f"InferenceRouter.embeddings: {model_id}")
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
@@ -418,15 +279,15 @@ class SafetyRouter(Safety):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing SafetyRouter")
+        logger.debug("Initializing SafetyRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
-        logcat.debug("core", "SafetyRouter.initialize")
+        logger.debug("SafetyRouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logcat.debug("core", "SafetyRouter.shutdown")
+        logger.debug("SafetyRouter.shutdown")
         pass

     async def register_shield(
@@ -436,7 +297,7 @@ class SafetyRouter(Safety):
         provider_id: Optional[str] = None,
         params: Optional[Dict[str, Any]] = None,
     ) -> Shield:
-        logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}")
+        logger.debug(f"SafetyRouter.register_shield: {shield_id}")
         return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

     async def run_shield(
@@ -445,7 +306,7 @@ class SafetyRouter(Safety):
         messages: List[Message],
         params: Dict[str, Any] = None,
     ) -> RunShieldResponse:
-        logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}")
+        logger.debug(f"SafetyRouter.run_shield: {shield_id}")
         return await self.routing_table.get_provider_impl(shield_id).run_shield(
             shield_id=shield_id,
             messages=messages,
@@ -458,15 +319,15 @@ class DatasetIORouter(DatasetIO):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing DatasetIORouter")
+        logger.debug("Initializing DatasetIORouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
-        logcat.debug("core", "DatasetIORouter.initialize")
+        logger.debug("DatasetIORouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logcat.debug("core", "DatasetIORouter.shutdown")
+        logger.debug("DatasetIORouter.shutdown")
         pass

     async def get_rows_paginated(
@@ -476,8 +337,7 @@ class DatasetIORouter(DatasetIO):
         page_token: Optional[str] = None,
         filter_condition: Optional[str] = None,
     ) -> PaginatedRowsResult:
-        logcat.debug(
-            "core",
+        logger.debug(
             f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
         )
         return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
@@ -488,7 +348,7 @@ class DatasetIORouter(DatasetIO):
         )

     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
-        logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
+        logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
         return await self.routing_table.get_provider_impl(dataset_id).append_rows(
             dataset_id=dataset_id,
             rows=rows,
@@ -500,15 +360,15 @@ class ScoringRouter(Scoring):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing ScoringRouter")
+        logger.debug("Initializing ScoringRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
-        logcat.debug("core", "ScoringRouter.initialize")
+        logger.debug("ScoringRouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logcat.debug("core", "ScoringRouter.shutdown")
+        logger.debug("ScoringRouter.shutdown")
         pass

     async def score_batch(
@@ -517,7 +377,7 @@ class ScoringRouter(Scoring):
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
-        logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}")
+        logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
@@ -538,10 +398,7 @@ class ScoringRouter(Scoring):
         input_rows: List[Dict[str, Any]],
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
-        logcat.debug(
-            "core",
-            f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions",
-        )
+        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
@@ -559,15 +416,15 @@ class EvalRouter(Eval):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing EvalRouter")
+        logger.debug("Initializing EvalRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
-        logcat.debug("core", "EvalRouter.initialize")
+        logger.debug("EvalRouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logcat.debug("core", "EvalRouter.shutdown")
+        logger.debug("EvalRouter.shutdown")
         pass

     async def run_eval(
@@ -575,7 +432,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         benchmark_config: BenchmarkConfig,
     ) -> Job:
-        logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}")
+        logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
             benchmark_id=benchmark_id,
             benchmark_config=benchmark_config,
@@ -588,7 +445,7 @@ class EvalRouter(Eval):
         scoring_functions: List[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
-        logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
+        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
         return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
@@ -601,7 +458,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> Optional[JobStatus]:
-        logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}")
+        logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

     async def job_cancel(
@@ -609,7 +466,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> None:
-        logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
+        logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
         await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
             benchmark_id,
             job_id,
@@ -620,7 +477,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> EvaluateResponse:
-        logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}")
+        logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).job_result(
             benchmark_id,
             job_id,
@@ -633,7 +490,7 @@ class ToolRuntimeRouter(ToolRuntime):
             self,
             routing_table: RoutingTable,
         ) -> None:
-            logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl")
+            logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
             self.routing_table = routing_table

         async def query(
@@ -642,7 +499,7 @@ class ToolRuntimeRouter(ToolRuntime):
             vector_db_ids: List[str],
             query_config: Optional[RAGQueryConfig] = None,
         ) -> RAGQueryResult:
-            logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
+            logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
             return await self.routing_table.get_provider_impl("knowledge_search").query(
                 content, vector_db_ids, query_config
             )
@@ -654,9 +511,8 @@ class ToolRuntimeRouter(ToolRuntime):
             chunk_size_in_tokens: int = 512,
             preprocessor_chain: Optional[PreprocessorChain] = None,
         ) -> None:
-            logcat.debug(
-                "core",
-                f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}",
+            logger.debug(
+                f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
             )
             return await self.routing_table.get_provider_impl("insert_into_memory").insert(
                 documents, vector_db_id, chunk_size_in_tokens, preprocessor_chain
@@ -666,7 +522,7 @@ class ToolRuntimeRouter(ToolRuntime):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing ToolRuntimeRouter")
+        logger.debug("Initializing ToolRuntimeRouter")
        self.routing_table = routing_table

        # HACK ALERT this should be in sync with "get_all_api_endpoints()"
@@ -675,15 +531,15 @@ class ToolRuntimeRouter(ToolRuntime):
             setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))

     async def initialize(self) -> None:
-        logcat.debug("core", "ToolRuntimeRouter.initialize")
+        logger.debug("ToolRuntimeRouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logcat.debug("core", "ToolRuntimeRouter.shutdown")
+        logger.debug("ToolRuntimeRouter.shutdown")
         pass

     async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
-        logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}")
+        logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
         return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
             tool_name=tool_name,
             kwargs=kwargs,
@@ -692,7 +548,7 @@ class ToolRuntimeRouter(ToolRuntime):
     async def list_runtime_tools(
         self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
     ) -> List[ToolDef]:
-        logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
+        logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
         return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)

@@ -701,15 +557,15 @@ class PreprocessingRouter(Preprocessing):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing PreprocessingRouter")
+        logger.debug("Initializing PreprocessingRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
-        logcat.debug("core", "PreprocessingRouter.initialize")
+        logger.debug("PreprocessingRouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logcat.debug("core", "PreprocessingRouter.shutdown")
+        logger.debug("PreprocessingRouter.shutdown")
         pass

     async def preprocess(
@@ -717,8 +573,7 @@ class PreprocessingRouter(Preprocessing):
         preprocessors: PreprocessorChain,
         preprocessor_inputs: List[PreprocessingDataElement],
     ) -> PreprocessorResponse:
-        logcat.debug(
-            "core",
+        logger.debug(
             f"PreprocessingRouter.chain_preprocess: preprocessors {[p.preprocessor_id for p in preprocessors]}, {len(preprocessor_inputs)} inputs",
         )
         preprocessor_impls = [self.routing_table.get_provider_impl(p.preprocessor_id) for p in preprocessors]
@@ -9,7 +9,6 @@ import asyncio
 import functools
 import inspect
 import json
-import logging
 import os
 import signal
 import sys
@@ -18,7 +17,7 @@ import warnings
 from contextlib import asynccontextmanager
 from importlib.metadata import version as parse_version
 from pathlib import Path
-from typing import Any, List, Union
+from typing import Any, List, Optional, Union

 import yaml
 from fastapi import Body, FastAPI, HTTPException, Request
@@ -28,10 +27,12 @@ from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel, ValidationError
 from typing_extensions import Annotated

-from llama_stack import logcat
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
-from llama_stack.distribution.request_headers import set_request_provider_data
+from llama_stack.distribution.request_headers import (
+    preserve_headers_context_async_generator,
+    request_provider_data_context,
+)
 from llama_stack.distribution.resolver import InvalidProviderError
 from llama_stack.distribution.stack import (
     construct_stack,
@@ -39,6 +40,7 @@ from llama_stack.distribution.stack import (
     replace_env_vars,
     validate_env_pair,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api
 from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
 from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
@@ -54,8 +56,7 @@ from .endpoints import get_all_api_endpoints

 REPO_ROOT = Path(__file__).parent.parent.parent.parent

-logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
-logcat.init()
+logger = get_logger(name=__name__, category="server")


 def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@@ -142,23 +143,23 @@ def handle_signal(app, signum, _) -> None:
     not block the current execution.
     """
     signame = signal.Signals(signum).name
-    logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...")
+    logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")

     async def shutdown():
         try:
             # Gracefully shut down implementations
             for impl in app.__llama_stack_impls__.values():
                 impl_name = impl.__class__.__name__
-                logcat.info("server", f"Shutting down {impl_name}")
+                logger.info("Shutting down %s", impl_name)
                 try:
                     if hasattr(impl, "shutdown"):
                         await asyncio.wait_for(impl.shutdown(), timeout=5)
                     else:
-                        logcat.warning("server", f"No shutdown method for {impl_name}")
+                        logger.warning("No shutdown method for %s", impl_name)
                 except asyncio.TimeoutError:
-                    logcat.exception("server", f"Shutdown timeout for {impl_name}")
+                    logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
                 except Exception as e:
-                    logcat.exception("server", f"Failed to shutdown {impl_name}: {e}")
+                    logger.exception("Failed to shutdown %s: %s", impl_name, {e})

             # Gather all running tasks
             loop = asyncio.get_running_loop()
@@ -172,7 +173,7 @@ def handle_signal(app, signum, _) -> None:
             try:
                 await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
             except asyncio.TimeoutError:
-                logcat.exception("server", "Timeout while waiting for tasks to finish")
+                logger.exception("Timeout while waiting for tasks to finish")
         except asyncio.CancelledError:
             pass
         finally:
@@ -184,9 +185,9 @@ def handle_signal(app, signum, _) -> None:

 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    logcat.info("server", "Starting up")
+    logger.info("Starting up")
     yield
-    logcat.info("server", "Shutting down")
+    logger.info("Shutting down")
     for impl in app.__llama_stack_impls__.values():
         await impl.shutdown()

@@ -204,16 +205,14 @@ async def maybe_await(value):

 async def sse_generator(event_gen):
     try:
-        event_gen = await event_gen
-        async for item in event_gen:
+        async for item in await event_gen:
             yield create_sse_event(item)
             await asyncio.sleep(0.01)
     except asyncio.CancelledError:
-        logcat.info("server", "Generator cancelled")
+        logger.info("Generator cancelled")
         await event_gen.aclose()
     except Exception as e:
-        logcat.exception("server", f"Error in sse_generator: {e}")
-        logcat.exception("server", f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
+        logger.exception("Error in sse_generator")
         yield create_sse_event(
             {
                 "error": {
@@ -225,18 +224,20 @@ async def sse_generator(event_gen):

 def create_dynamic_typed_route(func: Any, method: str, route: str):
     async def endpoint(request: Request, **kwargs):
-        set_request_provider_data(request.headers)
+        # Use context manager for request provider data
+        with request_provider_data_context(request.headers):
+            is_streaming = is_streaming_request(func.__name__, request, **kwargs)

-        is_streaming = is_streaming_request(func.__name__, request, **kwargs)
-        try:
-            if is_streaming:
-                return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
-            else:
-                value = func(**kwargs)
-                return await maybe_await(value)
-        except Exception as e:
-            logcat.exception("server", f"Error in {func.__name__}")
-            raise translate_exception(e) from e
+            try:
+                if is_streaming:
+                    gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
+                    return StreamingResponse(gen, media_type="text/event-stream")
+                else:
+                    value = func(**kwargs)
+                    return await maybe_await(value)
+            except Exception as e:
+                logger.exception(f"Error executing endpoint {route=} {method=}")
+                raise translate_exception(e) from e

     sig = inspect.signature(func)
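On the wire, the value that request_provider_data_context() parses arrives as a JSON-encoded header. A hedged client-side sketch (endpoint path, port, and payload are illustrative only, not taken from this diff):

    import json

    import httpx

    resp = httpx.post(
        "http://localhost:8321/v1/example-endpoint",  # hypothetical route
        headers={"X-LlamaStack-Provider-Data": json.dumps({"api_key": "sk-test"})},
        json={},
    )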
@@ -313,13 +314,17 @@ class ClientVersionMiddleware:
         return await self.app(scope, receive, send)


-def main():
-    logcat.init()
-
+def main(args: Optional[argparse.Namespace] = None):
+    """Start the LlamaStack server."""
     parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
     parser.add_argument(
         "--yaml-config",
-        help="Path to YAML configuration file",
+        dest="config",
+        help="(Deprecated) Path to YAML configuration file - use --config instead",
+    )
+    parser.add_argument(
+        "--config",
+        dest="config",
+        help="Path to YAML configuration file",
     )
     parser.add_argument(
@@ -349,29 +354,41 @@ def main():
         required="--tls-keyfile" in sys.argv,
     )

-    args = parser.parse_args()
+    # Determine whether the server args are being passed by the "run" command, if this is the case
+    # the args will be passed as a Namespace object to the main function, otherwise they will be
+    # parsed from the command line
+    if args is None:
+        args = parser.parse_args()
+
+    # Check for deprecated argument usage
+    if "--yaml-config" in sys.argv:
+        warnings.warn(
+            "The '--yaml-config' argument is deprecated and will be removed in a future version. Use '--config' instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )

     if args.env:
         for env_pair in args.env:
             try:
                 key, value = validate_env_pair(env_pair)
-                logcat.info("server", f"Setting CLI environment variable {key} => {value}")
+                logger.info(f"Setting CLI environment variable {key} => {value}")
                 os.environ[key] = value
             except ValueError as e:
-                logcat.error("server", f"Error: {str(e)}")
+                logger.error(f"Error: {str(e)}")
                 sys.exit(1)

-    if args.yaml_config:
+    if args.config:
         # if the user provided a config file, use it, even if template was specified
-        config_file = Path(args.yaml_config)
+        config_file = Path(args.config)
         if not config_file.exists():
             raise ValueError(f"Config file {config_file} does not exist")
-        logcat.info("server", f"Using config file: {config_file}")
+        logger.info(f"Using config file: {config_file}")
     elif args.template:
         config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
         if not config_file.exists():
             raise ValueError(f"Template {args.template} does not exist")
-        logcat.info("server", f"Using template {args.template} config file: {config_file}")
+        logger.info(f"Using template {args.template} config file: {config_file}")
     else:
         raise ValueError("Either --yaml-config or --template must be provided")
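The `args: Optional[argparse.Namespace] = None` parameter exists so that the `llama stack run` command can start the server in-process by handing over an already-parsed Namespace; only when nothing is passed does main() fall back to parse_args(). A hedged sketch of such a caller (attribute set assumed, not shown in this diff):

    import argparse

    # Hypothetical in-process caller, e.g. from the "run" command:
    args = argparse.Namespace(
        config="run.yaml",
        template=None,
        env=None,
        # ...plus the remaining server flags with their defaults
    )
    # main(args)  # skips parse_args() because args is not None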
@@ -379,10 +396,9 @@ def main():
         config = replace_env_vars(yaml.safe_load(fp))
         config = StackRunConfig(**config)

-    logcat.info("server", "Run configuration:")
+    logger.info("Run configuration:")
     safe_config = redact_sensitive_fields(config.model_dump())
-    for log_line in yaml.dump(safe_config, indent=2).split("\n"):
-        logcat.info("server", log_line)
+    logger.info(yaml.dump(safe_config, indent=2))

     app = FastAPI(lifespan=lifespan)
     app.add_middleware(TracingMiddleware)
@@ -392,7 +408,7 @@ def main():
     try:
         impls = asyncio.run(construct_stack(config))
     except InvalidProviderError as e:
-        logcat.error("server", f"Error: {str(e)}")
+        logger.error(f"Error: {str(e)}")
         sys.exit(1)

     if Api.telemetry in impls:
@@ -437,7 +453,7 @@ def main():
         )
     )

-    logcat.debug("server", f"serving APIs: {apis_to_serve}")
+    logger.debug(f"serving APIs: {apis_to_serve}")

     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
@@ -464,10 +480,10 @@ def main():
             "ssl_keyfile": keyfile,
             "ssl_certfile": certfile,
         }
-        logcat.info("server", f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
+        logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")

     listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
-    logcat.info("server", f"Listening on {listen_host}:{port}")
+    logger.info(f"Listening on {listen_host}:{port}")

     uvicorn_config = {
         "app": app,
@@ -11,9 +11,7 @@ import tempfile
 from typing import Any, Dict, Optional

 import yaml
-from termcolor import colored

-from llama_stack import logcat
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.batch_inference import BatchInference
 from llama_stack.apis.benchmarks import Benchmarks
@@ -41,8 +39,11 @@ from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
 from llama_stack.distribution.store.registry import create_dist_registry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api

+logger = get_logger(name=__name__, category="core")
+

 class LlamaStack(
     VectorDBs,
@@ -106,9 +107,8 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
         objects_to_process = response.data if hasattr(response, "data") else response

         for obj in objects_to_process:
-            logcat.debug(
-                "core",
-                f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
+            logger.debug(
+                f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
             )

@@ -100,12 +100,15 @@ esac

 if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
+    set -x
+
     $PYTHON_BINARY -m llama_stack.distribution.server.server \
     --yaml-config "$yaml_config" \
     --port "$port" \
     $env_vars \
     $other_args
 elif [[ "$env_type" == "container" ]]; then
+    set -x
     # Check if container command is available
     if ! is_command_available $CONTAINER_BINARY; then
         printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
@@ -141,8 +144,6 @@ elif [[ "$env_type" == "container" ]]; then
     version_tag=$(curl -s $URL | jq -r '.info.version')
   fi

-  set -x
-
   $CONTAINER_BINARY run $CONTAINER_OPTS -it \
     -p $port:$port \
     $env_vars \
@@ -20,14 +20,14 @@ import importlib
 import json
 from pathlib import Path

-from llama_stack.distribution.utils.image_types import ImageType
+from llama_stack.distribution.utils.image_types import LlamaStackImageType


 def formulate_run_args(image_type, image_name, config, template_name) -> list:
     env_name = ""
-    if image_type == ImageType.container.value or config.container_image:
+    if image_type == LlamaStackImageType.CONTAINER.value or config.container_image:
         env_name = f"distribution-{template_name}" if template_name else config.container_image
-    elif image_type == ImageType.conda.value:
+    elif image_type == LlamaStackImageType.CONDA.value:
         current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
         env_name = image_name or current_conda_env
         if not env_name:
@@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from enum import Enum
+import enum


-class ImageType(Enum):
-    container = "container"
-    conda = "conda"
-    venv = "venv"
+class LlamaStackImageType(enum.Enum):
+    CONTAINER = "container"
+    CONDA = "conda"
+    VENV = "venv"
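The rename also flips the members to the conventional SCREAMING_CASE for enums; the string values are unchanged, so anything persisted or passed on the command line keeps working. A quick check:

    import enum

    class LlamaStackImageType(enum.Enum):
        CONTAINER = "container"
        CONDA = "conda"
        VENV = "venv"

    assert LlamaStackImageType.CONTAINER.value == "container"  # wire value unchanged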