fix trace starting in library client

This commit is contained in:
Dinesh Yeduguru 2024-12-18 21:46:29 -08:00
parent 36b4fe02cc
commit 69ad23d41f

View file

@ -67,6 +67,7 @@ def in_notebook():
def stream_across_asyncio_run_boundary(
async_gen_maker,
pool_executor: ThreadPoolExecutor,
path: Optional[str] = None,
) -> Generator[T, None, None]:
result_queue = queue.Queue()
stop_event = threading.Event()
@ -74,6 +75,7 @@ def stream_across_asyncio_run_boundary(
async def consumer():
# make sure we make the generator in the event loop context
gen = await async_gen_maker()
await start_trace(path, {"__location__": "library_client"})
try:
async for item in await gen:
result_queue.put(item)
@ -85,6 +87,7 @@ def stream_across_asyncio_run_boundary(
finally:
result_queue.put(StopIteration)
stop_event.set()
await end_trace()
def run_async():
# Run our own loop to avoid double async generator cleanup which is done
@ -186,14 +189,34 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
return asyncio.run(self.async_client.initialize())
def _get_path(
self,
cast_to: Any,
options: Any,
*,
stream=False,
stream_cls=None,
):
return options.url
def request(self, *args, **kwargs):
path = self._get_path(*args, **kwargs)
if kwargs.get("stream"):
return stream_across_asyncio_run_boundary(
lambda: self.async_client.request(*args, **kwargs),
self.pool_executor,
path=path,
)
else:
return asyncio.run(self.async_client.request(*args, **kwargs))
async def _traced_request():
await start_trace(path, {"__location__": "library_client"})
try:
return await self.async_client.request(*args, **kwargs)
finally:
await end_trace()
return asyncio.run(_traced_request())
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
@ -206,7 +229,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
# when using the library client, we should not log to console since many
# of our logs are intended for server-side usage
os.environ["TELEMETRY_SINKS"] = "sqlite"
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(
sink for sink in current_sinks if sink != "console"
)
if config_path_or_template_name.endswith(".yaml"):
config_path = Path(config_path_or_template_name)
@ -295,41 +321,37 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = options.params or {}
body |= options.json_data or {}
await start_trace(path, {"__location__": "library_client"})
try:
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body)
result = await func(**body)
body = self._convert_body(path, body)
result = await func(**body)
json_content = json.dumps(convert_pydantic_to_json_value(result))
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=json_content.encode("utf-8"),
headers={
"Content-Type": "application/json",
},
request=httpx.Request(
method=options.method,
url=options.url,
params=options.params,
headers=options.headers,
json=options.json_data,
),
)
response = APIResponse(
raw=mock_response,
client=self,
cast_to=cast_to,
options=options,
stream=False,
stream_cls=None,
)
return response.parse()
finally:
await end_trace()
json_content = json.dumps(convert_pydantic_to_json_value(result))
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=json_content.encode("utf-8"),
headers={
"Content-Type": "application/json",
},
request=httpx.Request(
method=options.method,
url=options.url,
params=options.params,
headers=options.headers,
json=options.json_data,
),
)
response = APIResponse(
raw=mock_response,
client=self,
cast_to=cast_to,
options=options,
stream=False,
stream_cls=None,
)
return response.parse()
async def _call_streaming(
self,
@ -341,51 +363,47 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
path = options.url
body = options.params or {}
body |= options.json_data or {}
await start_trace(path, {"__location__": "library_client"})
try:
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body)
body = self._convert_body(path, body)
async def gen():
async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk))
sse_event = f"data: {data}\n\n"
yield sse_event.encode("utf-8")
async def gen():
async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk))
sse_event = f"data: {data}\n\n"
yield sse_event.encode("utf-8")
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=gen(),
headers={
"Content-Type": "application/json",
},
request=httpx.Request(
method=options.method,
url=options.url,
params=options.params,
headers=options.headers,
json=options.json_data,
),
)
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=gen(),
headers={
"Content-Type": "application/json",
},
request=httpx.Request(
method=options.method,
url=options.url,
params=options.params,
headers=options.headers,
json=options.json_data,
),
)
# we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
# however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
# so we need to convert it to AsyncStream
args = get_args(stream_cls)
stream_cls = AsyncStream[args[0]]
response = AsyncAPIResponse(
raw=mock_response,
client=self,
cast_to=cast_to,
options=options,
stream=True,
stream_cls=stream_cls,
)
return await response.parse()
finally:
await end_trace()
# we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
# however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
# so we need to convert it to AsyncStream
args = get_args(stream_cls)
stream_cls = AsyncStream[args[0]]
response = AsyncAPIResponse(
raw=mock_response,
client=self,
cast_to=cast_to,
options=options,
stream=True,
stream_cls=stream_cls,
)
return await response.parse()
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
if not body: