mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00

fix agentic calling inference

parent 2501b3d7de
commit f55ffa8b53

4 changed files with 8 additions and 22 deletions
@@ -132,7 +132,7 @@ async def run_main(host: str, port: int):
     api = AgenticSystemClient(f"http://{host}:{port}")

     tool_definitions = [
-        SearchToolDefinition(engine=SearchEngineType.brave),
+        SearchToolDefinition(engine=SearchEngineType.bing),
         WolframAlphaToolDefinition(),
         CodeInterpreterToolDefinition(),
     ]
@@ -388,19 +388,17 @@ class ChatAgent(ShieldRunnerMixin):
             )
         )

-        req = ChatCompletionRequest(
-            model=self.agent_config.model,
-            messages=input_messages,
+        tool_calls = []
+        content = ""
+        stop_reason = None
+        async for chunk in self.inference_api.chat_completion(
+            self.agent_config.model,
+            input_messages,
             tools=self._get_tools(),
             tool_prompt_format=self.agent_config.tool_prompt_format,
             stream=True,
             sampling_params=sampling_params,
-        )
-
-        tool_calls = []
-        content = ""
-        stop_reason = None
-        async for chunk in self.inference_api.chat_completion_impl(req):
+        ):
             event = chunk.event
             if event.event_type == ChatCompletionResponseEventType.start:
                 continue
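For context, the hunk above moves ChatAgent from building a ChatCompletionRequest and calling the private chat_completion_impl to iterating the public inference_api.chat_completion stream directly. Below is a minimal, self-contained sketch of that consumption pattern; the names here (EventType, Chunk, fake_chat_completion) are stand-ins for illustration, not the llama-stack types.

import asyncio
from dataclasses import dataclass
from enum import Enum
from typing import AsyncIterator, List


class EventType(Enum):
    start = "start"
    progress = "progress"
    complete = "complete"


@dataclass
class Event:
    event_type: EventType
    delta: str = ""


@dataclass
class Chunk:
    event: Event


async def fake_chat_completion(
    model: str, messages: List[dict], stream: bool = True
) -> AsyncIterator[Chunk]:
    # Stand-in for the inference API: emits a start event, a few content
    # deltas, then a completion event.
    yield Chunk(Event(EventType.start))
    for piece in ("Hello", ", ", "agent"):
        yield Chunk(Event(EventType.progress, delta=piece))
    yield Chunk(Event(EventType.complete))


async def main() -> None:
    content = ""
    async for chunk in fake_chat_completion("my-model", [{"role": "user", "content": "hi"}]):
        event = chunk.event
        if event.event_type == EventType.start:
            continue
        if event.event_type == EventType.progress:
            content += event.delta
    print(content)  # -> "Hello, agent"


asyncio.run(main())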
@@ -77,14 +77,6 @@ class MetaReferenceInferenceImpl(Inference):
             logprobs=logprobs,
         )

-        async for chunk in self.chat_completion_impl(request):
-            yield chunk
-
-    async def chat_completion_impl(
-        self, request: ChatCompletionRequest
-    ) -> AsyncIterator[
-        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
-    ]:
         messages = prepare_messages(request)
         model = resolve_model(request.model)
         if model is None:
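The hunk above removes the delegating wrapper: chat_completion no longer forwards chunks from a separate chat_completion_impl, its body now lives in the public method itself. A minimal, self-contained sketch of this kind of inlining, using hypothetical names unrelated to the repo:

import asyncio
from typing import AsyncIterator


class Before:
    async def chat_completion(self, prompt: str) -> AsyncIterator[str]:
        # Thin wrapper: forwards every chunk from the private impl.
        async for chunk in self.chat_completion_impl(prompt):
            yield chunk

    async def chat_completion_impl(self, prompt: str) -> AsyncIterator[str]:
        for token in prompt.split():
            yield token


class After:
    async def chat_completion(self, prompt: str) -> AsyncIterator[str]:
        # The same logic inlined; one fewer level of indirection and no
        # private generator to keep in sync with the public signature.
        for token in prompt.split():
            yield token


async def main() -> None:
    print([t async for t in Before().chat_completion("hello streaming world")])
    print([t async for t in After().chat_completion("hello streaming world")])


asyncio.run(main())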
@@ -125,11 +125,7 @@ Event = Annotated[

 class Telemetry(Protocol):
     @webmethod(route="/telemetry/log_event")
-<<<<<<< migrate_request_wrapper
     async def log_event(self, event: Event) -> None: ...
-=======
-    async def log_event(self, event: Event): ...
->>>>>>> main

     @webmethod(route="/telemetry/get_trace", method="GET")
     async def get_trace(self, trace_id: str) -> Trace: ...
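This last hunk resolves a leftover merge conflict in the Telemetry protocol, keeping the migrate_request_wrapper side with the explicit -> None annotation. A minimal, self-contained sketch of the resulting declaration style; the @webmethod decorator is omitted and Event/Trace are stand-in classes, not the repo's types.

from typing import Protocol


class Event:
    ...


class Trace:
    ...


class Telemetry(Protocol):
    # Protocol methods are declared with "..." bodies and explicit return
    # annotations; implementations provide the actual behavior.
    async def log_event(self, event: Event) -> None: ...

    async def get_trace(self, trace_id: str) -> Trace: ...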