add tracing back to the lib cli (#595)

Adds back all the tracing logic removed from library client. also adds
back the logging to agent_instance.
This commit is contained in:
Dinesh Yeduguru 2024-12-11 08:44:20 -08:00 committed by GitHub
parent 1c03ba239e
commit e128f2547a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 76 additions and 117 deletions

View file

@ -24,6 +24,7 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import (
@ -32,6 +33,12 @@ from llama_stack.distribution.stack import (
replace_env_vars,
)
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
start_trace,
)
T = TypeVar("T")
@ -240,6 +247,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
return False
if Api.telemetry in self.impls:
setup_logger(self.impls[Api.telemetry])
console = Console()
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
console.print(yaml.dump(self.config.model_dump(), indent=2))
@ -276,21 +286,29 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
async def _call_non_streaming(
self, path: str, body: dict = None, cast_to: Any = None
):
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
await start_trace(path, {"__location__": "library_client"})
try:
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body)
return convert_pydantic_to_json_value(await func(**body), cast_to)
body = self._convert_body(path, body)
return convert_pydantic_to_json_value(await func(**body), cast_to)
finally:
await end_trace()
async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
await start_trace(path, {"__location__": "library_client"})
try:
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body)
async for chunk in await func(**body):
yield convert_pydantic_to_json_value(chunk, cast_to)
body = self._convert_body(path, body)
async for chunk in await func(**body):
yield convert_pydantic_to_json_value(chunk, cast_to)
finally:
await end_trace()
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
if not body: