mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)
fix trace starting in library client
parent 36b4fe02cc
commit 69ad23d41f

1 changed file with 94 additions and 76 deletions
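For context, a minimal, self-contained sketch of the pattern this fix adopts (shown in the diff below): the trace is started and ended inside the consumer coroutine that owns the background event loop, so streaming requests made through the sync client are traced in the same loop that runs them. The start_trace/end_trace functions here are printing stand-ins rather than the llama-stack telemetry helpers, and the queue-draining loop is condensed compared to the real implementation.

import asyncio
import queue
from concurrent.futures import ThreadPoolExecutor
from typing import AsyncGenerator, Generator, Optional, TypeVar

T = TypeVar("T")


async def start_trace(path: Optional[str], attributes: Optional[dict] = None) -> None:
    # stand-in for the real telemetry helper
    print(f"start_trace({path!r}, {attributes})")


async def end_trace() -> None:
    # stand-in for the real telemetry helper
    print("end_trace()")


def stream_across_asyncio_run_boundary(
    async_gen_maker,
    pool_executor: ThreadPoolExecutor,
    path: Optional[str] = None,
) -> Generator[T, None, None]:
    result_queue: queue.Queue = queue.Queue()

    async def consumer():
        # create the generator (and start the trace) inside the event loop context
        gen = async_gen_maker()
        await start_trace(path, {"__location__": "library_client"})
        try:
            async for item in gen:
                result_queue.put(item)
        finally:
            result_queue.put(StopIteration)
            await end_trace()

    # run the consumer in its own event loop on a worker thread
    future = pool_executor.submit(asyncio.run, consumer())
    try:
        while True:
            item = result_queue.get()
            if item is StopIteration:
                break
            yield item
    finally:
        future.result()


async def numbers() -> AsyncGenerator[int, None]:
    for i in range(3):
        yield i


if __name__ == "__main__":
    with ThreadPoolExecutor() as pool:
        for n in stream_across_asyncio_run_boundary(numbers, pool, path="/demo"):
            print("got", n)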
@@ -67,6 +67,7 @@ def in_notebook():
 def stream_across_asyncio_run_boundary(
     async_gen_maker,
     pool_executor: ThreadPoolExecutor,
+    path: Optional[str] = None,
 ) -> Generator[T, None, None]:
     result_queue = queue.Queue()
     stop_event = threading.Event()
@@ -74,6 +75,7 @@ def stream_across_asyncio_run_boundary(
     async def consumer():
         # make sure we make the generator in the event loop context
         gen = await async_gen_maker()
+        await start_trace(path, {"__location__": "library_client"})
         try:
             async for item in await gen:
                 result_queue.put(item)
@@ -85,6 +87,7 @@ def stream_across_asyncio_run_boundary(
         finally:
             result_queue.put(StopIteration)
             stop_event.set()
+            await end_trace()

     def run_async():
         # Run our own loop to avoid double async generator cleanup which is done
@@ -186,14 +189,34 @@ class LlamaStackAsLibraryClient(LlamaStackClient):

         return asyncio.run(self.async_client.initialize())

+    def _get_path(
+        self,
+        cast_to: Any,
+        options: Any,
+        *,
+        stream=False,
+        stream_cls=None,
+    ):
+        return options.url
+
     def request(self, *args, **kwargs):
+        path = self._get_path(*args, **kwargs)
         if kwargs.get("stream"):
             return stream_across_asyncio_run_boundary(
                 lambda: self.async_client.request(*args, **kwargs),
                 self.pool_executor,
+                path=path,
             )
         else:
-            return asyncio.run(self.async_client.request(*args, **kwargs))
+
+            async def _traced_request():
+                await start_trace(path, {"__location__": "library_client"})
+                try:
+                    return await self.async_client.request(*args, **kwargs)
+                finally:
+                    await end_trace()
+
+            return asyncio.run(_traced_request())


 class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
@@ -206,7 +229,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):

         # when using the library client, we should not log to console since many
         # of our logs are intended for server-side usage
-        os.environ["TELEMETRY_SINKS"] = "sqlite"
+        current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
+        os.environ["TELEMETRY_SINKS"] = ",".join(
+            sink for sink in current_sinks if sink != "console"
+        )

         if config_path_or_template_name.endswith(".yaml"):
             config_path = Path(config_path_or_template_name)
@@ -295,41 +321,37 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):

         body = options.params or {}
         body |= options.json_data or {}
-        await start_trace(path, {"__location__": "library_client"})
-        try:
-            func = self.endpoint_impls.get(path)
-            if not func:
-                raise ValueError(f"No endpoint found for {path}")
+        func = self.endpoint_impls.get(path)
+        if not func:
+            raise ValueError(f"No endpoint found for {path}")

-            body = self._convert_body(path, body)
-            result = await func(**body)
+        body = self._convert_body(path, body)
+        result = await func(**body)

-            json_content = json.dumps(convert_pydantic_to_json_value(result))
-            mock_response = httpx.Response(
-                status_code=httpx.codes.OK,
-                content=json_content.encode("utf-8"),
-                headers={
-                    "Content-Type": "application/json",
-                },
-                request=httpx.Request(
-                    method=options.method,
-                    url=options.url,
-                    params=options.params,
-                    headers=options.headers,
-                    json=options.json_data,
-                ),
-            )
-            response = APIResponse(
-                raw=mock_response,
-                client=self,
-                cast_to=cast_to,
-                options=options,
-                stream=False,
-                stream_cls=None,
-            )
-            return response.parse()
-        finally:
-            await end_trace()
+        json_content = json.dumps(convert_pydantic_to_json_value(result))
+        mock_response = httpx.Response(
+            status_code=httpx.codes.OK,
+            content=json_content.encode("utf-8"),
+            headers={
+                "Content-Type": "application/json",
+            },
+            request=httpx.Request(
+                method=options.method,
+                url=options.url,
+                params=options.params,
+                headers=options.headers,
+                json=options.json_data,
+            ),
+        )
+        response = APIResponse(
+            raw=mock_response,
+            client=self,
+            cast_to=cast_to,
+            options=options,
+            stream=False,
+            stream_cls=None,
+        )
+        return response.parse()

     async def _call_streaming(
         self,
@@ -341,51 +363,47 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         path = options.url
         body = options.params or {}
         body |= options.json_data or {}
-        await start_trace(path, {"__location__": "library_client"})
-        try:
-            func = self.endpoint_impls.get(path)
-            if not func:
-                raise ValueError(f"No endpoint found for {path}")
+        func = self.endpoint_impls.get(path)
+        if not func:
+            raise ValueError(f"No endpoint found for {path}")

-            body = self._convert_body(path, body)
+        body = self._convert_body(path, body)

-            async def gen():
-                async for chunk in await func(**body):
-                    data = json.dumps(convert_pydantic_to_json_value(chunk))
-                    sse_event = f"data: {data}\n\n"
-                    yield sse_event.encode("utf-8")
+        async def gen():
+            async for chunk in await func(**body):
+                data = json.dumps(convert_pydantic_to_json_value(chunk))
+                sse_event = f"data: {data}\n\n"
+                yield sse_event.encode("utf-8")

-            mock_response = httpx.Response(
-                status_code=httpx.codes.OK,
-                content=gen(),
-                headers={
-                    "Content-Type": "application/json",
-                },
-                request=httpx.Request(
-                    method=options.method,
-                    url=options.url,
-                    params=options.params,
-                    headers=options.headers,
-                    json=options.json_data,
-                ),
-            )
+        mock_response = httpx.Response(
+            status_code=httpx.codes.OK,
+            content=gen(),
+            headers={
+                "Content-Type": "application/json",
+            },
+            request=httpx.Request(
+                method=options.method,
+                url=options.url,
+                params=options.params,
+                headers=options.headers,
+                json=options.json_data,
+            ),
+        )

-            # we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
-            # however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
-            # so we need to convert it to AsyncStream
-            args = get_args(stream_cls)
-            stream_cls = AsyncStream[args[0]]
-            response = AsyncAPIResponse(
-                raw=mock_response,
-                client=self,
-                cast_to=cast_to,
-                options=options,
-                stream=True,
-                stream_cls=stream_cls,
-            )
-            return await response.parse()
-        finally:
-            await end_trace()
+        # we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
+        # however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
+        # so we need to convert it to AsyncStream
+        args = get_args(stream_cls)
+        stream_cls = AsyncStream[args[0]]
+        response = AsyncAPIResponse(
+            raw=mock_response,
+            client=self,
+            cast_to=cast_to,
+            options=options,
+            stream=True,
+            stream_cls=stream_cls,
+        )
+        return await response.parse()

     def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
         if not body:
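A side note on the stream_cls conversion in _call_streaming above: get_args() on a parameterized generic recovers its type argument, which can then be used to parameterize AsyncStream. A tiny illustration with stand-in classes (Stream, AsyncStream, and ChatCompletionChunk here are placeholders for this sketch, not the SDK types):

from typing import Generic, TypeVar, get_args

T = TypeVar("T")


class Stream(Generic[T]):
    """Placeholder for the sync SDK stream class."""


class AsyncStream(Generic[T]):
    """Placeholder for the async SDK stream class."""


class ChatCompletionChunk:
    """Hypothetical chunk type used only for this illustration."""


stream_cls = Stream[ChatCompletionChunk]
args = get_args(stream_cls)              # (ChatCompletionChunk,)
async_stream_cls = AsyncStream[args[0]]
print(async_stream_cls)                  # __main__.AsyncStream[__main__.ChatCompletionChunk]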