Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 17:29:01 +00:00)
fix trace starting in library client
This commit is contained in:
parent 36b4fe02cc
commit 69ad23d41f

1 changed file with 94 additions and 76 deletions
@@ -67,6 +67,7 @@ def in_notebook():
 def stream_across_asyncio_run_boundary(
     async_gen_maker,
     pool_executor: ThreadPoolExecutor,
+    path: Optional[str] = None,
 ) -> Generator[T, None, None]:
     result_queue = queue.Queue()
     stop_event = threading.Event()
@@ -74,6 +75,7 @@ def stream_across_asyncio_run_boundary(
     async def consumer():
         # make sure we make the generator in the event loop context
         gen = await async_gen_maker()
+        await start_trace(path, {"__location__": "library_client"})
         try:
             async for item in await gen:
                 result_queue.put(item)
@@ -85,6 +87,7 @@ def stream_across_asyncio_run_boundary(
         finally:
             result_queue.put(StopIteration)
             stop_event.set()
+            await end_trace()

     def run_async():
         # Run our own loop to avoid double async generator cleanup which is done
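Note: stream_across_asyncio_run_boundary hands items from an async generator, running under its own asyncio.run, to a synchronous caller through a thread-safe queue; the change above additionally opens a trace before iterating and closes it when the stream finishes. Below is a minimal, self-contained sketch of that producer/consumer pattern, using a plain thread and no-op stand-ins for start_trace/end_trace rather than the project's actual implementation; the sentinel object plays the role of the StopIteration marker used in the real code.

import asyncio
import queue
import threading
from typing import AsyncGenerator, Callable, Generator, Optional


# Hypothetical stand-ins for llama-stack's start_trace/end_trace telemetry hooks.
async def start_trace(path: Optional[str], attrs: dict) -> None:
    print(f"trace started for {path}")


async def end_trace() -> None:
    print("trace ended")


def stream_across_loop_boundary(
    async_gen_maker: Callable[[], AsyncGenerator],
    path: Optional[str] = None,
) -> Generator:
    """Yield items from an async generator to synchronous code."""
    result_queue: queue.Queue = queue.Queue()
    _sentinel = object()

    async def consumer() -> None:
        gen = async_gen_maker()
        await start_trace(path, {"__location__": "library_client"})
        try:
            async for item in gen:
                result_queue.put(item)
        finally:
            result_queue.put(_sentinel)  # signal completion to the sync side
            await end_trace()

    # Run the consumer's event loop in a worker thread.
    worker = threading.Thread(target=lambda: asyncio.run(consumer()))
    worker.start()
    try:
        while True:
            item = result_queue.get()
            if item is _sentinel:
                break
            yield item
    finally:
        worker.join()


async def numbers() -> AsyncGenerator[int, None]:
    for i in range(3):
        yield i


for n in stream_across_loop_boundary(numbers, path="/demo"):
    print(n)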
@@ -186,14 +189,34 @@ class LlamaStackAsLibraryClient(LlamaStackClient):

         return asyncio.run(self.async_client.initialize())

+    def _get_path(
+        self,
+        cast_to: Any,
+        options: Any,
+        *,
+        stream=False,
+        stream_cls=None,
+    ):
+        return options.url
+
     def request(self, *args, **kwargs):
+        path = self._get_path(*args, **kwargs)
         if kwargs.get("stream"):
             return stream_across_asyncio_run_boundary(
                 lambda: self.async_client.request(*args, **kwargs),
                 self.pool_executor,
+                path=path,
             )
         else:
-            return asyncio.run(self.async_client.request(*args, **kwargs))
+
+            async def _traced_request():
+                await start_trace(path, {"__location__": "library_client"})
+                try:
+                    return await self.async_client.request(*args, **kwargs)
+                finally:
+                    await end_trace()
+
+            return asyncio.run(_traced_request())


 class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
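Note: the synchronous, non-streaming path now drives the async client through a small local coroutine so that start_trace/end_trace bracket the request inside the same asyncio.run call. A stripped-down sketch of that wrapping pattern, with a dummy async client and no-op trace hooks standing in for the real ones:

import asyncio


# Hypothetical stand-ins for the real telemetry hooks and async client.
async def start_trace(path: str, attrs: dict) -> None:
    print(f"start trace: {path}")


async def end_trace() -> None:
    print("end trace")


class FakeAsyncClient:
    async def request(self, path: str) -> str:
        await asyncio.sleep(0)  # pretend to do I/O
        return f"response for {path}"


def request(client: FakeAsyncClient, path: str) -> str:
    # Wrap the async call so the trace starts and ends inside the event loop
    # created by asyncio.run, mirroring the pattern in the hunk above.
    async def _traced_request() -> str:
        await start_trace(path, {"__location__": "library_client"})
        try:
            return await client.request(path)
        finally:
            await end_trace()

    return asyncio.run(_traced_request())


print(request(FakeAsyncClient(), "/models/list"))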
@@ -206,7 +229,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):

         # when using the library client, we should not log to console since many
         # of our logs are intended for server-side usage
-        os.environ["TELEMETRY_SINKS"] = "sqlite"
+        current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
+        os.environ["TELEMETRY_SINKS"] = ",".join(
+            sink for sink in current_sinks if sink != "console"
+        )

         if config_path_or_template_name.endswith(".yaml"):
             config_path = Path(config_path_or_template_name)
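Note: the telemetry-sink change stops hard-coding the sink list and instead strips only the console sink from whatever is already configured, defaulting to sqlite. A quick illustration of the resulting behavior, assuming a user had both sinks enabled:

import os

# Assume the user had both sinks enabled before constructing the library client.
os.environ["TELEMETRY_SINKS"] = "console,sqlite"

current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(
    sink for sink in current_sinks if sink != "console"
)

print(os.environ["TELEMETRY_SINKS"])  # -> "sqlite"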
@@ -295,41 +321,37 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):

         body = options.params or {}
         body |= options.json_data or {}
-        await start_trace(path, {"__location__": "library_client"})
-        try:
-            func = self.endpoint_impls.get(path)
-            if not func:
-                raise ValueError(f"No endpoint found for {path}")
+        func = self.endpoint_impls.get(path)
+        if not func:
+            raise ValueError(f"No endpoint found for {path}")

-            body = self._convert_body(path, body)
-            result = await func(**body)
+        body = self._convert_body(path, body)
+        result = await func(**body)

-            json_content = json.dumps(convert_pydantic_to_json_value(result))
-            mock_response = httpx.Response(
-                status_code=httpx.codes.OK,
-                content=json_content.encode("utf-8"),
-                headers={
-                    "Content-Type": "application/json",
-                },
-                request=httpx.Request(
-                    method=options.method,
-                    url=options.url,
-                    params=options.params,
-                    headers=options.headers,
-                    json=options.json_data,
-                ),
-            )
-            response = APIResponse(
-                raw=mock_response,
-                client=self,
-                cast_to=cast_to,
-                options=options,
-                stream=False,
-                stream_cls=None,
-            )
-            return response.parse()
-        finally:
-            await end_trace()
+        json_content = json.dumps(convert_pydantic_to_json_value(result))
+        mock_response = httpx.Response(
+            status_code=httpx.codes.OK,
+            content=json_content.encode("utf-8"),
+            headers={
+                "Content-Type": "application/json",
+            },
+            request=httpx.Request(
+                method=options.method,
+                url=options.url,
+                params=options.params,
+                headers=options.headers,
+                json=options.json_data,
+            ),
+        )
+        response = APIResponse(
+            raw=mock_response,
+            client=self,
+            cast_to=cast_to,
+            options=options,
+            stream=False,
+            stream_cls=None,
+        )
+        return response.parse()

     async def _call_streaming(
         self,
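Note: _call_non_streaming packages the locally computed result as a synthetic httpx.Response so the generated client's normal parsing machinery (APIResponse) can be reused unchanged. A minimal sketch of just the httpx part, using only public httpx APIs and a made-up payload and URL:

import json

import httpx

# Hypothetical result from a local endpoint implementation.
result = {"identifier": "my-model", "provider_id": "inline"}

json_content = json.dumps(result)
mock_response = httpx.Response(
    status_code=httpx.codes.OK,
    content=json_content.encode("utf-8"),
    headers={"Content-Type": "application/json"},
    request=httpx.Request(method="GET", url="http://localhost/v1/models/my-model"),
)

# Downstream code can treat this exactly like a response that came off the wire.
print(mock_response.json()["identifier"])  # -> "my-model"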
@@ -341,51 +363,47 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         path = options.url
         body = options.params or {}
         body |= options.json_data or {}
-        await start_trace(path, {"__location__": "library_client"})
-        try:
-            func = self.endpoint_impls.get(path)
-            if not func:
-                raise ValueError(f"No endpoint found for {path}")
+        func = self.endpoint_impls.get(path)
+        if not func:
+            raise ValueError(f"No endpoint found for {path}")

-            body = self._convert_body(path, body)
+        body = self._convert_body(path, body)

-            async def gen():
-                async for chunk in await func(**body):
-                    data = json.dumps(convert_pydantic_to_json_value(chunk))
-                    sse_event = f"data: {data}\n\n"
-                    yield sse_event.encode("utf-8")
+        async def gen():
+            async for chunk in await func(**body):
+                data = json.dumps(convert_pydantic_to_json_value(chunk))
+                sse_event = f"data: {data}\n\n"
+                yield sse_event.encode("utf-8")

-            mock_response = httpx.Response(
-                status_code=httpx.codes.OK,
-                content=gen(),
-                headers={
-                    "Content-Type": "application/json",
-                },
-                request=httpx.Request(
-                    method=options.method,
-                    url=options.url,
-                    params=options.params,
-                    headers=options.headers,
-                    json=options.json_data,
-                ),
-            )
+        mock_response = httpx.Response(
+            status_code=httpx.codes.OK,
+            content=gen(),
+            headers={
+                "Content-Type": "application/json",
+            },
+            request=httpx.Request(
+                method=options.method,
+                url=options.url,
+                params=options.params,
+                headers=options.headers,
+                json=options.json_data,
+            ),
+        )

-            # we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
-            # however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
-            # so we need to convert it to AsyncStream
-            args = get_args(stream_cls)
-            stream_cls = AsyncStream[args[0]]
-            response = AsyncAPIResponse(
-                raw=mock_response,
-                client=self,
-                cast_to=cast_to,
-                options=options,
-                stream=True,
-                stream_cls=stream_cls,
-            )
-            return await response.parse()
-        finally:
-            await end_trace()
+        # we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
+        # however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
+        # so we need to convert it to AsyncStream
+        args = get_args(stream_cls)
+        stream_cls = AsyncStream[args[0]]
+        response = AsyncAPIResponse(
+            raw=mock_response,
+            client=self,
+            cast_to=cast_to,
+            options=options,
+            stream=True,
+            stream_cls=stream_cls,
+        )
+        return await response.parse()

     def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
         if not body:
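Note: in the streaming path, each chunk is serialized to JSON and framed as a Server-Sent Events "data:" record before being handed to the response machinery. A small sketch of that framing and of decoding it on the consuming side; the chunk contents and function names here are illustrative only:

import asyncio
import json
from typing import AsyncGenerator


async def fake_chunks() -> AsyncGenerator[dict, None]:
    # Hypothetical stand-in for the chunks an endpoint implementation yields.
    for i in range(3):
        yield {"delta": f"token-{i}"}


async def sse_stream() -> AsyncGenerator[bytes, None]:
    # Frame each chunk the same way the streaming path above does:
    # one "data: <json>" record per chunk, terminated by a blank line.
    async for chunk in fake_chunks():
        data = json.dumps(chunk)
        yield f"data: {data}\n\n".encode("utf-8")


async def main() -> None:
    async for event in sse_stream():
        line = event.decode("utf-8").strip()
        payload = json.loads(line.removeprefix("data: "))
        print(payload["delta"])


asyncio.run(main())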