forked from phoenix-oss/llama-stack-mirror
add tracing to library client (#591)
This commit is contained in:
parent
ab7145a04f
commit
bc1fddf1df
2 changed files with 49 additions and 17 deletions
|
@ -22,6 +22,7 @@ from termcolor import cprint
|
|||
|
||||
from llama_stack.distribution.build import print_pip_install_help
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.resolver import ProviderRegistry
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.distribution.stack import (
|
||||
|
@ -29,6 +30,11 @@ from llama_stack.distribution.stack import (
|
|||
get_stack_run_config_from_template,
|
||||
replace_env_vars,
|
||||
)
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
setup_logger,
|
||||
start_trace,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
@ -187,6 +193,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
)
|
||||
return False
|
||||
|
||||
# Set up telemetry logger similar to server.py
|
||||
if Api.telemetry in self.impls:
|
||||
setup_logger(self.impls[Api.telemetry])
|
||||
|
||||
console = Console()
|
||||
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
|
||||
console.print(yaml.dump(self.config.model_dump(), indent=2))
|
||||
|
@ -234,21 +244,29 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
return await self._call_non_streaming(path, "POST", body)
|
||||
|
||||
async def _call_non_streaming(self, path: str, method: str, body: dict = None):
|
||||
func = self.endpoint_impls.get(path)
|
||||
if not func:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
await start_trace(path, {"__location__": "library_client"})
|
||||
try:
|
||||
func = self.endpoint_impls.get(path)
|
||||
if not func:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
body = self._convert_body(path, body)
|
||||
return await func(**body)
|
||||
body = self._convert_body(path, body)
|
||||
return await func(**body)
|
||||
finally:
|
||||
end_trace()
|
||||
|
||||
async def _call_streaming(self, path: str, method: str, body: dict = None):
|
||||
func = self.endpoint_impls.get(path)
|
||||
if not func:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
await start_trace(path, {"__location__": "library_client"})
|
||||
try:
|
||||
func = self.endpoint_impls.get(path)
|
||||
if not func:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
body = self._convert_body(path, body)
|
||||
async for chunk in await func(**body):
|
||||
yield chunk
|
||||
body = self._convert_body(path, body)
|
||||
async for chunk in await func(**body):
|
||||
yield chunk
|
||||
finally:
|
||||
end_trace()
|
||||
|
||||
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
|
||||
if not body:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue