mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-10 04:08:31 +00:00)

commit 5595f5b9b8 (parent 753a1aa7bc)

    tmp

3 changed files with 18 additions and 0 deletions

@@ -168,10 +168,17 @@ class ChatAgent(ShieldRunnerMixin):
         if self.agent_config.instructions != "":
             messages.append(SystemMessage(content=self.agent_config.instructions))
 
+        from rich.pretty import pprint
+
+        print("create_and_execute_turn")
+        pprint(request)
+
         for i, turn in enumerate(turns):
             messages.extend(self.turn_to_messages(turn))
 
         messages.extend(request.messages)
+        print("create_and_execute_turn turn to messages")
+        pprint(messages)
 
         turn_id = str(uuid.uuid4())
         span.set_attribute("turn_id", turn_id)
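
The hunk above leans on `rich.pretty.pprint` for readable dumps of the request and message objects. A minimal standalone sketch of the same pattern follows; the `TurnRequest` dataclass is an invented stand-in, not a type from the repo, and `rich` must be installed (`pip install rich`).

# Sketch of the debug pattern used above: rich.pretty.pprint renders nested
# objects with indentation and highlighting that plain print() lacks.
from dataclasses import dataclass, field
from typing import List

from rich.pretty import pprint


@dataclass
class TurnRequest:  # invented stand-in for the real request type
    agent_id: str
    messages: List[str] = field(default_factory=list)


request = TurnRequest(agent_id="agent-123", messages=["hello", "what is 2 + 2?"])
print("create_and_execute_turn")
pprint(request)  # pretty-printed, unlike bare repr(request)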

@@ -360,6 +367,7 @@ class ChatAgent(ShieldRunnerMixin):
         documents: Optional[List[Document]] = None,
         toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
     ) -> AsyncGenerator:
+        print("_run messages", input_messages)
         # TODO: simplify all of this code, it can be simpler
         toolgroup_args = {}
         toolgroups = set()

@@ -490,6 +498,7 @@ class ChatAgent(ShieldRunnerMixin):
         stop_reason = None
 
         with tracing.span("inference") as span:
+            print("just before chat completion", input_messages)
             async for chunk in await self.inference_api.chat_completion(
                 self.agent_config.model,
                 input_messages,
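
The `async for chunk in await ...` shape in the hunk above is worth noting: `chat_completion` is a coroutine that resolves to an async generator when streaming, so it must be awaited before iteration. A self-contained sketch of that calling convention, with an invented `chat_completion` stand-in:

import asyncio
from typing import AsyncGenerator


async def chat_completion(model: str, messages: list) -> AsyncGenerator[str, None]:
    # Invented stand-in: a coroutine that *returns* an async generator,
    # which is why the caller writes `async for ... in await ...`.
    async def stream() -> AsyncGenerator[str, None]:
        for token in ["Hel", "lo", "!"]:
            await asyncio.sleep(0)  # stands in for real network I/O
            yield token

    return stream()


async def main() -> None:
    async for chunk in await chat_completion("my-model", [{"role": "user", "content": "hi"}]):
        print("chunk:", chunk)


asyncio.run(main())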

@@ -196,6 +196,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
+        print("inside together chat completion messages", messages)
+        breakpoint()
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
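
The `breakpoint()` added above is the PEP 553 built-in: it routes through `sys.breakpointhook()`, which defaults to `pdb.set_trace()`, and it can be redirected or disabled via the `PYTHONBREAKPOINT` environment variable, which matters for throwaway instrumentation like this. A small sketch with a hypothetical handler:

# PYTHONBREAKPOINT=0 makes every breakpoint() call a no-op;
# PYTHONBREAKPOINT=module.callable substitutes another debugger.
def handle_chat(messages: list) -> dict:  # hypothetical handler, for illustration only
    print("inside chat completion messages", messages)
    breakpoint()  # drops into pdb unless PYTHONBREAKPOINT=0 is set
    return {"echo": messages}


if __name__ == "__main__":
    handle_chat([{"role": "user", "content": "hi"}])

Running with `PYTHONBREAKPOINT=0 python demo.py` keeps the prints but skips the debugger, so the same build can run unattended.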

@@ -223,6 +225,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
+        from rich.pretty import pprint
+
+        print("together stream completion")
+        pprint(request)
+        pprint(params)
 
         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
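
The `# if we shift to TogetherAsyncClient, we won't need this wrapper` comment refers to adapting a blocking streaming client into an async generator. One common way to do that without stalling the event loop is to pull items on a worker thread; this is a sketch of the general technique, not the repo's actual `_to_async_generator`, and `blocking_stream()` is invented in place of the real Together client:

import asyncio
from typing import AsyncGenerator, Iterator


def blocking_stream() -> Iterator[str]:
    # Stand-in for a synchronous streaming client (e.g. a sync SDK iterator).
    yield from ["one", "two", "three"]


async def to_async_generator(it: Iterator[str]) -> AsyncGenerator[str, None]:
    # Pull each item on a worker thread so the event loop stays responsive;
    # a sentinel marks exhaustion, since next() would otherwise raise StopIteration.
    sentinel = object()
    while (item := await asyncio.to_thread(next, it, sentinel)) is not sentinel:
        yield item


async def main() -> None:
    async for chunk in to_async_generator(blocking_stream()):
        print(chunk)


asyncio.run(main())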

@@ -240,6 +247,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
+        breakpoint()
         if isinstance(request, ChatCompletionRequest):
             if media_present:
                 input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
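
The `_get_params` hunk shows the dispatch pattern in play: branch on the request type, then shape a provider-specific payload dict. A condensed illustration with invented stand-in types (the real method also branches on media presence and is async):

from dataclasses import dataclass
from typing import List, Union


@dataclass
class ChatCompletionRequest:  # invented stand-in types, for illustration only
    messages: List[dict]


@dataclass
class CompletionRequest:
    content: str


def get_params(request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
    # isinstance() dispatch, mirroring the hunk above: chat requests carry
    # a message list, plain completion requests carry a single prompt.
    input_dict: dict = {}
    if isinstance(request, ChatCompletionRequest):
        input_dict["messages"] = request.messages
    else:
        input_dict["prompt"] = request.content
    return input_dict


print(get_params(ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}])))
print(get_params(CompletionRequest(content="Once upon a time")))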

@@ -375,6 +375,7 @@ def augment_messages_for_tools_llama_3_1(
 def augment_messages_for_tools_llama_3_2(
     request: ChatCompletionRequest,
 ) -> List[Message]:
+    breakpoint()
     assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
 
     existing_messages = request.messages
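
The assert in this last hunk is a fail-fast guard on an unsupported configuration rather than a handled branch. A tiny sketch of the same guard style, with an invented `ToolChoice` enum standing in for the real one:

from enum import Enum


class ToolChoice(Enum):  # invented stand-in for the real enum
    auto = "auto"
    required = "required"


def augment_messages(tool_choice: ToolChoice) -> None:
    # Fail fast on unsupported configurations instead of silently ignoring them.
    assert tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
    print("augmenting with", tool_choice)


augment_messages(ToolChoice.auto)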