Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-28 10:54:19 +00:00
fix trace starting in library client (#655)
# What does this PR do?

Because of the way the library client sets up asyncio boundaries, tracing was broken with streaming. This PR fixes tracing so it starts in the right place and correctly captures the lifetime of async generator functions.

## Test plan

Script run: https://gist.github.com/yanxi0830/f6645129e55ab12de3cd6ec71564c69e

Before: no spans were returned for a session.

Now: we see spans.

<img width="1678" alt="Screenshot 2024-12-18 at 9 50 46 PM" src="https://github.com/user-attachments/assets/58a3b0dd-a41c-489a-b89a-075e698a2c03" />
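To illustrate the idea behind the fix, here is a minimal, self-contained sketch (not the actual implementation) of starting a trace inside the consumer coroutine that drives the async generator, so the span covers the generator's full lifetime even though results are handed across an `asyncio.run()` boundary to a synchronous caller. The `start_trace`/`end_trace` names mirror the telemetry helpers touched in the diff; everything else below is simplified stand-in code.

```python
import asyncio
import queue
import threading


# Hypothetical stand-ins for the telemetry helpers referenced in the diff.
async def start_trace(path, attributes=None):
    print(f"trace started: {path} {attributes}")


async def end_trace():
    print("trace ended")


def stream_across_asyncio_run_boundary(async_gen_maker, path=None):
    """Drive an async generator on its own event loop and yield items synchronously."""
    result_queue: queue.Queue = queue.Queue()

    async def consumer():
        gen = async_gen_maker()
        # Start the trace in the event-loop context so the span wraps the
        # whole lifetime of the async generator, not just the initial call.
        await start_trace(path, {"__location__": "library_client"})
        try:
            async for item in gen:
                result_queue.put(item)
        finally:
            result_queue.put(StopIteration)
            await end_trace()

    threading.Thread(target=lambda: asyncio.run(consumer()), daemon=True).start()
    while True:
        item = result_queue.get()
        if item is StopIteration:
            return
        yield item


async def chunks():
    for i in range(3):
        yield f"chunk-{i}"


if __name__ == "__main__":
    for item in stream_across_asyncio_run_boundary(chunks, path="/inference/chat-completion"):
        print(item)
```

In the real client the generator is additionally guarded by a stop event and driven via a ThreadPoolExecutor, but the placement of `start_trace`/`end_trace` around the `async for` loop is the part this PR changes.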
This commit is contained in:
parent ddf37ea467
commit 8b8d1c1ef4

1 changed file with 94 additions and 76 deletions
@@ -67,6 +67,7 @@ def in_notebook():
 def stream_across_asyncio_run_boundary(
     async_gen_maker,
     pool_executor: ThreadPoolExecutor,
+    path: Optional[str] = None,
 ) -> Generator[T, None, None]:
     result_queue = queue.Queue()
     stop_event = threading.Event()
@@ -74,6 +75,7 @@ def stream_across_asyncio_run_boundary(
     async def consumer():
         # make sure we make the generator in the event loop context
         gen = await async_gen_maker()
+        await start_trace(path, {"__location__": "library_client"})
         try:
             async for item in await gen:
                 result_queue.put(item)
@@ -85,6 +87,7 @@ def stream_across_asyncio_run_boundary(
         finally:
             result_queue.put(StopIteration)
             stop_event.set()
+            await end_trace()
 
     def run_async():
         # Run our own loop to avoid double async generator cleanup which is done
@@ -186,14 +189,34 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
 
         return asyncio.run(self.async_client.initialize())
 
+    def _get_path(
+        self,
+        cast_to: Any,
+        options: Any,
+        *,
+        stream=False,
+        stream_cls=None,
+    ):
+        return options.url
+
     def request(self, *args, **kwargs):
+        path = self._get_path(*args, **kwargs)
         if kwargs.get("stream"):
             return stream_across_asyncio_run_boundary(
                 lambda: self.async_client.request(*args, **kwargs),
                 self.pool_executor,
+                path=path,
             )
         else:
-            return asyncio.run(self.async_client.request(*args, **kwargs))
+
+            async def _traced_request():
+                await start_trace(path, {"__location__": "library_client"})
+                try:
+                    return await self.async_client.request(*args, **kwargs)
+                finally:
+                    await end_trace()
+
+            return asyncio.run(_traced_request())
 
 
 class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
@@ -206,7 +229,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
 
         # when using the library client, we should not log to console since many
        # of our logs are intended for server-side usage
-        os.environ["TELEMETRY_SINKS"] = "sqlite"
+        current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
+        os.environ["TELEMETRY_SINKS"] = ",".join(
+            sink for sink in current_sinks if sink != "console"
+        )
 
         if config_path_or_template_name.endswith(".yaml"):
             config_path = Path(config_path_or_template_name)
@@ -295,41 +321,37 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
 
         body = options.params or {}
         body |= options.json_data or {}
-        await start_trace(path, {"__location__": "library_client"})
-        try:
-            func = self.endpoint_impls.get(path)
-            if not func:
-                raise ValueError(f"No endpoint found for {path}")
+        func = self.endpoint_impls.get(path)
+        if not func:
+            raise ValueError(f"No endpoint found for {path}")
 
-            body = self._convert_body(path, body)
-            result = await func(**body)
+        body = self._convert_body(path, body)
+        result = await func(**body)
 
-            json_content = json.dumps(convert_pydantic_to_json_value(result))
-            mock_response = httpx.Response(
-                status_code=httpx.codes.OK,
-                content=json_content.encode("utf-8"),
-                headers={
-                    "Content-Type": "application/json",
-                },
-                request=httpx.Request(
-                    method=options.method,
-                    url=options.url,
-                    params=options.params,
-                    headers=options.headers,
-                    json=options.json_data,
-                ),
-            )
-            response = APIResponse(
-                raw=mock_response,
-                client=self,
-                cast_to=cast_to,
-                options=options,
-                stream=False,
-                stream_cls=None,
-            )
-            return response.parse()
-        finally:
-            await end_trace()
+        json_content = json.dumps(convert_pydantic_to_json_value(result))
+        mock_response = httpx.Response(
+            status_code=httpx.codes.OK,
+            content=json_content.encode("utf-8"),
+            headers={
+                "Content-Type": "application/json",
+            },
+            request=httpx.Request(
+                method=options.method,
+                url=options.url,
+                params=options.params,
+                headers=options.headers,
+                json=options.json_data,
+            ),
+        )
+        response = APIResponse(
+            raw=mock_response,
+            client=self,
+            cast_to=cast_to,
+            options=options,
+            stream=False,
+            stream_cls=None,
+        )
+        return response.parse()
 
     async def _call_streaming(
         self,
@@ -341,51 +363,47 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         path = options.url
         body = options.params or {}
         body |= options.json_data or {}
-        await start_trace(path, {"__location__": "library_client"})
-        try:
-            func = self.endpoint_impls.get(path)
-            if not func:
-                raise ValueError(f"No endpoint found for {path}")
+        func = self.endpoint_impls.get(path)
+        if not func:
+            raise ValueError(f"No endpoint found for {path}")
 
-            body = self._convert_body(path, body)
+        body = self._convert_body(path, body)
 
-            async def gen():
-                async for chunk in await func(**body):
-                    data = json.dumps(convert_pydantic_to_json_value(chunk))
-                    sse_event = f"data: {data}\n\n"
-                    yield sse_event.encode("utf-8")
+        async def gen():
+            async for chunk in await func(**body):
+                data = json.dumps(convert_pydantic_to_json_value(chunk))
+                sse_event = f"data: {data}\n\n"
+                yield sse_event.encode("utf-8")
 
-            mock_response = httpx.Response(
-                status_code=httpx.codes.OK,
-                content=gen(),
-                headers={
-                    "Content-Type": "application/json",
-                },
-                request=httpx.Request(
-                    method=options.method,
-                    url=options.url,
-                    params=options.params,
-                    headers=options.headers,
-                    json=options.json_data,
-                ),
-            )
+        mock_response = httpx.Response(
+            status_code=httpx.codes.OK,
+            content=gen(),
+            headers={
+                "Content-Type": "application/json",
+            },
+            request=httpx.Request(
+                method=options.method,
+                url=options.url,
+                params=options.params,
+                headers=options.headers,
+                json=options.json_data,
+            ),
+        )
 
-            # we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
-            # however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
-            # so we need to convert it to AsyncStream
-            args = get_args(stream_cls)
-            stream_cls = AsyncStream[args[0]]
-            response = AsyncAPIResponse(
-                raw=mock_response,
-                client=self,
-                cast_to=cast_to,
-                options=options,
-                stream=True,
-                stream_cls=stream_cls,
-            )
-            return await response.parse()
-        finally:
-            await end_trace()
+        # we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
+        # however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
+        # so we need to convert it to AsyncStream
+        args = get_args(stream_cls)
+        stream_cls = AsyncStream[args[0]]
+        response = AsyncAPIResponse(
+            raw=mock_response,
+            client=self,
+            cast_to=cast_to,
+            options=options,
+            stream=True,
+            stream_cls=stream_cls,
+        )
+        return await response.parse()
 
     def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
         if not body: