Mirror of https://github.com/meta-llama/llama-stack.git (synced 2026-01-01 21:10:00 +00:00)

Commit 8d7bb1140f: Merge remote-tracking branch 'origin/main' into support_more_data_format
20 changed files with 381 additions and 414 deletions
@@ -40,7 +40,12 @@ from llama_stack.apis.agents import (
     ToolExecutionStep,
     Turn,
 )
-from llama_stack.apis.common.content_types import TextContentItem, URL
+from llama_stack.apis.common.content_types import (
+    TextContentItem,
+    ToolCallDelta,
+    ToolCallParseStatus,
+    URL,
+)
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
     CompletionMessage,
@@ -49,8 +54,6 @@ from llama_stack.apis.inference import (
     SamplingParams,
     StopReason,
     SystemMessage,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolDefinition,
     ToolResponse,
     ToolResponseMessage,
@@ -411,8 +414,8 @@ class ChatAgent(ShieldRunnerMixin):
                     payload=AgentTurnResponseStepProgressPayload(
                         step_type=StepType.tool_execution.value,
                         step_id=step_id,
-                        tool_call_delta=ToolCallDelta(
-                            parse_status=ToolCallParseStatus.success,
+                        delta=ToolCallDelta(
+                            parse_status=ToolCallParseStatus.succeeded,
                             content=ToolCall(
                                 call_id="",
                                 tool_name=MEMORY_QUERY_TOOL,
@@ -507,8 +510,8 @@ class ChatAgent(ShieldRunnerMixin):
                     continue

                 delta = event.delta
-                if isinstance(delta, ToolCallDelta):
-                    if delta.parse_status == ToolCallParseStatus.success:
+                if delta.type == "tool_call":
+                    if delta.parse_status == ToolCallParseStatus.succeeded:
                         tool_calls.append(delta.content)
                         if stream:
                             yield AgentTurnResponseStreamChunk(
@@ -516,21 +519,20 @@ class ChatAgent(ShieldRunnerMixin):
                                 payload=AgentTurnResponseStepProgressPayload(
                                     step_type=StepType.inference.value,
                                     step_id=step_id,
-                                    text_delta="",
-                                    tool_call_delta=delta,
+                                    delta=delta,
                                 )
                             )
                         )

-                elif isinstance(delta, str):
-                    content += delta
+                elif delta.type == "text":
+                    content += delta.text
                     if stream and event.stop_reason is None:
                         yield AgentTurnResponseStreamChunk(
                             event=AgentTurnResponseEvent(
                                 payload=AgentTurnResponseStepProgressPayload(
                                     step_type=StepType.inference.value,
                                     step_id=step_id,
-                                    text_delta=event.delta,
+                                    delta=delta,
                                 )
                             )
                         )
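Note on the agent hunks above: the streaming loop now dispatches on the union delta's `type` field ("tool_call" vs "text") instead of isinstance checks, and a fully parsed tool call is marked ToolCallParseStatus.succeeded. A minimal consumer sketch, assuming the post-change llama_stack API shown in these hunks; the accumulate_deltas helper itself is hypothetical and not part of the commit.

from llama_stack.apis.common.content_types import ToolCallParseStatus


def accumulate_deltas(deltas):
    """Fold streamed TextDelta / ToolCallDelta objects into (text, tool_calls)."""
    content = ""
    tool_calls = []
    for delta in deltas:
        if delta.type == "tool_call":
            # Only fully parsed tool calls carry a ToolCall object in `content`.
            if delta.parse_status == ToolCallParseStatus.succeeded:
                tool_calls.append(delta.content)
        elif delta.type == "text":
            content += delta.text
    return content, tool_calls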
@@ -16,6 +16,11 @@ from llama_models.llama3.api.datatypes import (
 )
 from llama_models.sku_list import resolve_model

+from llama_stack.apis.common.content_types import (
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -32,8 +37,6 @@ from llama_stack.apis.inference import (
     Message,
     ResponseFormat,
     TokenLogProbs,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolChoice,
 )
 from llama_stack.apis.models import Model, ModelType
@@ -190,14 +193,14 @@ class MetaReferenceInferenceImpl(
                 ]

             yield CompletionResponseStreamChunk(
-                delta=text,
+                delta=TextDelta(text=text),
                 stop_reason=stop_reason,
                 logprobs=logprobs if request.logprobs else None,
             )

             if stop_reason is None:
                 yield CompletionResponseStreamChunk(
-                    delta="",
+                    delta=TextDelta(text=""),
                     stop_reason=StopReason.out_of_tokens,
                 )

@@ -352,7 +355,7 @@ class MetaReferenceInferenceImpl(
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type=ChatCompletionResponseEventType.start,
-                delta="",
+                delta=TextDelta(text=""),
             )
         )

@@ -392,7 +395,7 @@ class MetaReferenceInferenceImpl(
                     parse_status=ToolCallParseStatus.in_progress,
                 )
             else:
-                delta = text
+                delta = TextDelta(text=text)

             if stop_reason is None:
                 if request.logprobs:
@@ -428,7 +431,7 @@ class MetaReferenceInferenceImpl(
                     event_type=ChatCompletionResponseEventType.progress,
                     delta=ToolCallDelta(
                         content="",
-                        parse_status=ToolCallParseStatus.failure,
+                        parse_status=ToolCallParseStatus.failed,
                     ),
                     stop_reason=stop_reason,
                 )
@@ -440,7 +443,7 @@ class MetaReferenceInferenceImpl(
                         event_type=ChatCompletionResponseEventType.progress,
                         delta=ToolCallDelta(
                             content=tool_call,
-                            parse_status=ToolCallParseStatus.success,
+                            parse_status=ToolCallParseStatus.succeeded,
                         ),
                         stop_reason=stop_reason,
                     )
@@ -449,7 +452,7 @@ class MetaReferenceInferenceImpl(
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type=ChatCompletionResponseEventType.complete,
-                delta="",
+                delta=TextDelta(text=""),
                 stop_reason=stop_reason,
             )
         )
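The meta-reference inference hunks above replace bare-string deltas with TextDelta(text=...) everywhere a chunk is emitted. A hedged sketch of the resulting event shape, assuming ChatCompletionResponseEvent and ChatCompletionResponseStreamChunk are importable from llama_stack.apis.inference as the surrounding code implies; text_chunk is an illustrative helper, not part of the commit.

from llama_stack.apis.common.content_types import TextDelta
from llama_stack.apis.inference import (
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
)


def text_chunk(event_type: ChatCompletionResponseEventType, text: str, stop_reason=None):
    """Wrap a piece of generated text in the structured TextDelta type."""
    return ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=event_type,
            delta=TextDelta(text=text),
            stop_reason=stop_reason,
        )
    )


# Usage mirroring the start / complete events in the hunks above:
#   text_chunk(ChatCompletionResponseEventType.start, "")
#   text_chunk(ChatCompletionResponseEventType.complete, "", stop_reason=stop_reason)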
@@ -30,13 +30,10 @@ from llama_stack.apis.telemetry import (
     Trace,
     UnstructuredLogEvent,
 )
-
 from llama_stack.distribution.datatypes import Api
-
 from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
     ConsoleSpanProcessor,
 )
-
 from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
     SQLiteSpanProcessor,
 )
@@ -52,6 +49,7 @@ _GLOBAL_STORAGE = {
     "up_down_counters": {},
 }
 _global_lock = threading.Lock()
+_TRACER_PROVIDER = None


 def string_to_trace_id(s: str) -> int:
@@ -80,31 +78,34 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             }
         )

-        provider = TracerProvider(resource=resource)
-        trace.set_tracer_provider(provider)
-        if TelemetrySink.OTEL in self.config.sinks:
-            otlp_exporter = OTLPSpanExporter(
-                endpoint=self.config.otel_endpoint,
-            )
-            span_processor = BatchSpanProcessor(otlp_exporter)
-            trace.get_tracer_provider().add_span_processor(span_processor)
-            metric_reader = PeriodicExportingMetricReader(
-                OTLPMetricExporter(
+        global _TRACER_PROVIDER
+        if _TRACER_PROVIDER is None:
+            provider = TracerProvider(resource=resource)
+            trace.set_tracer_provider(provider)
+            _TRACER_PROVIDER = provider
+            if TelemetrySink.OTEL in self.config.sinks:
+                otlp_exporter = OTLPSpanExporter(
                     endpoint=self.config.otel_endpoint,
                 )
-            )
-            metric_provider = MeterProvider(
-                resource=resource, metric_readers=[metric_reader]
-            )
-            metrics.set_meter_provider(metric_provider)
-            self.meter = metrics.get_meter(__name__)
-        if TelemetrySink.SQLITE in self.config.sinks:
-            trace.get_tracer_provider().add_span_processor(
-                SQLiteSpanProcessor(self.config.sqlite_db_path)
-            )
-            self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
-        if TelemetrySink.CONSOLE in self.config.sinks:
-            trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
+                span_processor = BatchSpanProcessor(otlp_exporter)
+                trace.get_tracer_provider().add_span_processor(span_processor)
+                metric_reader = PeriodicExportingMetricReader(
+                    OTLPMetricExporter(
+                        endpoint=self.config.otel_endpoint,
+                    )
+                )
+                metric_provider = MeterProvider(
+                    resource=resource, metric_readers=[metric_reader]
+                )
+                metrics.set_meter_provider(metric_provider)
+                self.meter = metrics.get_meter(__name__)
+            if TelemetrySink.SQLITE in self.config.sinks:
+                trace.get_tracer_provider().add_span_processor(
+                    SQLiteSpanProcessor(self.config.sqlite_db_path)
+                )
+                self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
+            if TelemetrySink.CONSOLE in self.config.sinks:
+                trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
         self._lock = _global_lock

     async def initialize(self) -> None:
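The telemetry change above guards provider setup behind a module-level _TRACER_PROVIDER so that constructing a second TelemetryAdapter does not register another TracerProvider. A self-contained sketch of just that once-only pattern (ensure_tracer_provider is a hypothetical name; the real adapter additionally wires the OTLP, SQLite, and console sinks):

from opentelemetry import trace
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider

_TRACER_PROVIDER = None


def ensure_tracer_provider(resource: Resource) -> TracerProvider:
    """Create and register the global TracerProvider exactly once."""
    global _TRACER_PROVIDER
    if _TRACER_PROVIDER is None:
        provider = TracerProvider(resource=resource)
        trace.set_tracer_provider(provider)
        _TRACER_PROVIDER = provider
    return _TRACER_PROVIDER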
@@ -30,6 +30,11 @@ from groq.types.shared.function_definition import FunctionDefinition

 from llama_models.llama3.api.datatypes import ToolParamDefinition

+from llama_stack.apis.common.content_types import (
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -40,8 +45,6 @@ from llama_stack.apis.inference import (
     Message,
     StopReason,
     ToolCall,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -162,7 +165,7 @@ def convert_chat_completion_response(


 def _map_finish_reason_to_stop_reason(
-    finish_reason: Literal["stop", "length", "tool_calls"]
+    finish_reason: Literal["stop", "length", "tool_calls"],
 ) -> StopReason:
     """
     Convert a Groq chat completion finish_reason to a StopReason.
@@ -185,7 +188,6 @@ def _map_finish_reason_to_stop_reason(
 async def convert_chat_completion_response_stream(
     stream: Stream[ChatCompletionChunk],
 ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-
     event_type = ChatCompletionResponseEventType.start
     for chunk in stream:
         choice = chunk.choices[0]
@@ -194,7 +196,7 @@ async def convert_chat_completion_response_stream(
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type=ChatCompletionResponseEventType.complete,
-                delta=choice.delta.content or "",
+                delta=TextDelta(text=choice.delta.content or ""),
                 logprobs=None,
                 stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
             )
@@ -213,7 +215,7 @@ async def convert_chat_completion_response_stream(
                     event_type=event_type,
                     delta=ToolCallDelta(
                         content=tool_call,
-                        parse_status=ToolCallParseStatus.success,
+                        parse_status=ToolCallParseStatus.succeeded,
                     ),
                 )
             )
@@ -221,7 +223,7 @@ async def convert_chat_completion_response_stream(
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=event_type,
-                    delta=choice.delta.content or "",
+                    delta=TextDelta(text=choice.delta.content or ""),
                     logprobs=None,
                 )
             )
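In the Groq conversion above, a fully parsed tool call is now delivered as a ToolCallDelta with parse_status=ToolCallParseStatus.succeeded rather than the old success value. A hypothetical helper (not in the commit) showing that event shape, assuming the same llama_stack imports the hunks use:

from llama_stack.apis.common.content_types import ToolCallDelta, ToolCallParseStatus
from llama_stack.apis.inference import (
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
)


def tool_call_chunk(tool_call, event_type=ChatCompletionResponseEventType.progress):
    """Package an already-parsed ToolCall into a streamed tool-call delta."""
    return ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=event_type,
            delta=ToolCallDelta(
                content=tool_call,
                parse_status=ToolCallParseStatus.succeeded,
            ),
            logprobs=None,
        )
    )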
@@ -34,6 +34,11 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
 from openai.types.completion import Completion as OpenAICompletion
 from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs

+from llama_stack.apis.common.content_types import (
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -48,8 +53,6 @@ from llama_stack.apis.inference import (
     Message,
     SystemMessage,
     TokenLogProbs,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolResponseMessage,
     UserMessage,
 )
@@ -432,69 +435,6 @@ async def convert_openai_chat_completion_stream(
     """
     Convert a stream of OpenAI chat completion chunks into a stream
     of ChatCompletionResponseStreamChunk.
-
-    OpenAI ChatCompletionChunk:
-        choices: List[Choice]
-
-    OpenAI Choice: # different from the non-streamed Choice
-        delta: ChoiceDelta
-        finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]]
-        logprobs: Optional[ChoiceLogprobs]
-
-    OpenAI ChoiceDelta:
-        content: Optional[str]
-        role: Optional[Literal["system", "user", "assistant", "tool"]]
-        tool_calls: Optional[List[ChoiceDeltaToolCall]]
-
-    OpenAI ChoiceDeltaToolCall:
-        index: int
-        id: Optional[str]
-        function: Optional[ChoiceDeltaToolCallFunction]
-        type: Optional[Literal["function"]]
-
-    OpenAI ChoiceDeltaToolCallFunction:
-        name: Optional[str]
-        arguments: Optional[str]
-
-    ->
-
-    ChatCompletionResponseStreamChunk:
-        event: ChatCompletionResponseEvent
-
-    ChatCompletionResponseEvent:
-        event_type: ChatCompletionResponseEventType
-        delta: Union[str, ToolCallDelta]
-        logprobs: Optional[List[TokenLogProbs]]
-        stop_reason: Optional[StopReason]
-
-    ChatCompletionResponseEventType:
-        start = "start"
-        progress = "progress"
-        complete = "complete"
-
-    ToolCallDelta:
-        content: Union[str, ToolCall]
-        parse_status: ToolCallParseStatus
-
-    ToolCall:
-        call_id: str
-        tool_name: str
-        arguments: str
-
-    ToolCallParseStatus:
-        started = "started"
-        in_progress = "in_progress"
-        failure = "failure"
-        success = "success"
-
-    TokenLogProbs:
-        logprobs_by_token: Dict[str, float]
-            - token, logprob
-
-    StopReason:
-        end_of_turn = "end_of_turn"
-        end_of_message = "end_of_message"
-        out_of_tokens = "out_of_tokens"
     """

     # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
@@ -543,7 +483,7 @@ async def convert_openai_chat_completion_stream(
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=next(event_type),
-                    delta=choice.delta.content,
+                    delta=TextDelta(text=choice.delta.content),
                     logprobs=_convert_openai_logprobs(choice.logprobs),
                 )
             )
@@ -561,7 +501,7 @@ async def convert_openai_chat_completion_stream(
                         event_type=next(event_type),
                         delta=ToolCallDelta(
                             content=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
-                            parse_status=ToolCallParseStatus.success,
+                            parse_status=ToolCallParseStatus.succeeded,
                         ),
                         logprobs=_convert_openai_logprobs(choice.logprobs),
                     )
@@ -570,7 +510,7 @@ async def convert_openai_chat_completion_stream(
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=next(event_type),
-                    delta=choice.delta.content or "",  # content is not optional
+                    delta=TextDelta(text=choice.delta.content or ""),
                     logprobs=_convert_openai_logprobs(choice.logprobs),
                 )
             )
@@ -578,7 +518,7 @@ async def convert_openai_chat_completion_stream(
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
             event_type=ChatCompletionResponseEventType.complete,
-            delta="",
+            delta=TextDelta(text=""),
             stop_reason=stop_reason,
         )
     )
@@ -653,18 +593,6 @@ def _convert_openai_completion_logprobs(
 ) -> Optional[List[TokenLogProbs]]:
     """
     Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
-
-    OpenAI CompletionLogprobs:
-        text_offset: Optional[List[int]]
-        token_logprobs: Optional[List[float]]
-        tokens: Optional[List[str]]
-        top_logprobs: Optional[List[Dict[str, float]]]
-
-    ->
-
-    TokenLogProbs:
-        logprobs_by_token: Dict[str, float]
-            - token, logprob
     """
     if not logprobs:
         return None
@@ -679,28 +607,6 @@ def convert_openai_completion_choice(
 ) -> CompletionResponse:
     """
     Convert an OpenAI Completion Choice into a CompletionResponse.
-
-    OpenAI Completion Choice:
-        text: str
-        finish_reason: str
-        logprobs: Optional[ChoiceLogprobs]
-
-    ->
-
-    CompletionResponse:
-        completion_message: CompletionMessage
-        logprobs: Optional[List[TokenLogProbs]]
-
-    CompletionMessage:
-        role: Literal["assistant"]
-        content: str | ImageMedia | List[str | ImageMedia]
-        stop_reason: StopReason
-        tool_calls: List[ToolCall]
-
-    class StopReason(Enum):
-        end_of_turn = "end_of_turn"
-        end_of_message = "end_of_message"
-        out_of_tokens = "out_of_tokens"
     """
     return CompletionResponse(
         content=choice.text,
@@ -715,32 +621,11 @@ async def convert_openai_completion_stream(
     """
     Convert a stream of OpenAI Completions into a stream
     of ChatCompletionResponseStreamChunks.
-
-    OpenAI Completion:
-        id: str
-        choices: List[OpenAICompletionChoice]
-        created: int
-        model: str
-        system_fingerprint: Optional[str]
-        usage: Optional[OpenAICompletionUsage]
-
-    OpenAI CompletionChoice:
-        finish_reason: str
-        index: int
-        logprobs: Optional[OpenAILogprobs]
-        text: str
-
-    ->
-
-    CompletionResponseStreamChunk:
-        delta: str
-        stop_reason: Optional[StopReason]
-        logprobs: Optional[List[TokenLogProbs]]
     """
     async for chunk in stream:
         choice = chunk.choices[0]
         yield CompletionResponseStreamChunk(
-            delta=choice.text,
+            delta=TextDelta(text=choice.text),
             stop_reason=_convert_openai_finish_reason(choice.finish_reason),
             logprobs=_convert_openai_completion_logprobs(choice.logprobs),
         )
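The OpenAI-compat hunks above drop the long type-listing docstrings and wrap text payloads in TextDelta. A hedged sketch of the updated completion-stream shape; the real converter maps finish_reason and logprobs through private helpers in this module, which are stubbed out as None here, and convert_completion_stream is an illustrative name, not the patched function.

from llama_stack.apis.common.content_types import TextDelta
from llama_stack.apis.inference import CompletionResponseStreamChunk


async def convert_completion_stream(stream):
    """Yield CompletionResponseStreamChunk objects carrying TextDelta payloads."""
    async for chunk in stream:
        choice = chunk.choices[0]
        yield CompletionResponseStreamChunk(
            delta=TextDelta(text=choice.text),
            stop_reason=None,  # the real code derives this from choice.finish_reason
            logprobs=None,  # the real code converts choice.logprobs
        )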
@@ -18,6 +18,7 @@ from llama_models.llama3.api.datatypes import (

 from pydantic import BaseModel, ValidationError

+from llama_stack.apis.common.content_types import ToolCallParseStatus
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEventType,
@@ -27,8 +28,6 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     LogProbConfig,
     SystemMessage,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolChoice,
     UserMessage,
 )
@@ -196,7 +195,9 @@ class TestInference:
             1 <= len(chunks) <= 6
         )  # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
         for chunk in chunks:
-            if chunk.delta:  # if there's a token, we expect logprobs
+            if (
+                chunk.delta.type == "text" and chunk.delta.text
+            ):  # if there's a token, we expect logprobs
                 assert chunk.logprobs, "Logprobs should not be empty"
                 assert all(
                     len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
@@ -463,7 +464,7 @@ class TestInference:

         if "Llama3.1" in inference_model:
             assert all(
-                isinstance(chunk.event.delta, ToolCallDelta)
+                chunk.event.delta.type == "tool_call"
                 for chunk in grouped[ChatCompletionResponseEventType.progress]
             )
             first = grouped[ChatCompletionResponseEventType.progress][0]
@@ -474,8 +475,8 @@ class TestInference:

         last = grouped[ChatCompletionResponseEventType.progress][-1]
         # assert last.event.stop_reason == expected_stop_reason
-        assert last.event.delta.parse_status == ToolCallParseStatus.success
-        assert isinstance(last.event.delta.content, ToolCall)
+        assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
+        assert last.event.delta.content.type == "tool_call"

         call = last.event.delta.content
         assert call.tool_name == "get_weather"
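The test updates above assert on the union delta's type field rather than using isinstance checks. A test-style sketch of the same checks (check_final_tool_call is illustrative and not part of the commit):

from llama_stack.apis.common.content_types import ToolCallParseStatus


def check_final_tool_call(last_chunk, expected_tool_name: str) -> None:
    """Assert that the last progress chunk carries a fully parsed tool call."""
    delta = last_chunk.event.delta
    assert delta.type == "tool_call"
    assert delta.parse_status == ToolCallParseStatus.succeeded
    assert delta.content.type == "tool_call"
    assert delta.content.tool_name == expected_tool_name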
@@ -11,7 +11,13 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import SamplingParams, StopReason
 from pydantic import BaseModel

-from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
+from llama_stack.apis.common.content_types import (
+    ImageContentItem,
+    TextContentItem,
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)

 from llama_stack.apis.inference import (
     ChatCompletionResponse,
@@ -22,8 +28,6 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     Message,
-    ToolCallDelta,
-    ToolCallParseStatus,
 )

 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -160,7 +164,7 @@ async def process_chat_completion_stream_response(
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
             event_type=ChatCompletionResponseEventType.start,
-            delta="",
+            delta=TextDelta(text=""),
         )
     )

@@ -227,7 +231,7 @@ async def process_chat_completion_stream_response(
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
-                        delta=text,
+                        delta=TextDelta(text=text),
                         stop_reason=stop_reason,
                     )
                 )
@@ -241,7 +245,7 @@ async def process_chat_completion_stream_response(
                     event_type=ChatCompletionResponseEventType.progress,
                     delta=ToolCallDelta(
                         content="",
-                        parse_status=ToolCallParseStatus.failure,
+                        parse_status=ToolCallParseStatus.failed,
                     ),
                     stop_reason=stop_reason,
                 )
@@ -253,7 +257,7 @@ async def process_chat_completion_stream_response(
                     event_type=ChatCompletionResponseEventType.progress,
                     delta=ToolCallDelta(
                         content=tool_call,
-                        parse_status=ToolCallParseStatus.success,
+                        parse_status=ToolCallParseStatus.succeeded,
                     ),
                     stop_reason=stop_reason,
                 )
@@ -262,7 +266,7 @@ async def process_chat_completion_stream_response(
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
             event_type=ChatCompletionResponseEventType.complete,
-            delta="",
+            delta=TextDelta(text=""),
             stop_reason=stop_reason,
         )
     )
@@ -265,6 +265,7 @@ def chat_completion_request_to_messages(
     For eg. for llama_3_1, add system message with the appropriate tools or
     add user messsage for custom tools, etc.
     """
+    assert llama_model is not None, "llama_model is required"
     model = resolve_model(llama_model)
     if model is None:
         log.error(f"Could not resolve model {llama_model}")
@@ -127,7 +127,8 @@ class TraceContext:
 def setup_logger(api: Telemetry, level: int = logging.INFO):
     global BACKGROUND_LOGGER

-    BACKGROUND_LOGGER = BackgroundLogger(api)
+    if BACKGROUND_LOGGER is None:
+        BACKGROUND_LOGGER = BackgroundLogger(api)
     logger = logging.getLogger()
     logger.setLevel(level)
     logger.addHandler(TelemetryHandler())
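The tracing change above applies the same once-only pattern to the background logger: setup_logger now creates BackgroundLogger(api) only when the module-level BACKGROUND_LOGGER is still None. A generic, self-contained sketch of that guard; the factory argument stands in for BackgroundLogger, which is internal to the patched module, and setup_once is a hypothetical name.

import logging

BACKGROUND_LOGGER = None


def setup_once(make_background_logger, level: int = logging.INFO) -> None:
    """Create the background logger on the first call only; later calls reuse it."""
    global BACKGROUND_LOGGER
    if BACKGROUND_LOGGER is None:
        BACKGROUND_LOGGER = make_background_logger()
    root = logging.getLogger()
    root.setLevel(level)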