diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 9ddb070d7..fd782f6c9 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -6372,6 +6372,9 @@
"$ref": "#/components/schemas/TokenLogProbs"
},
"description": "Optional log probabilities for generated tokens"
+ },
+ "usage": {
+ "$ref": "#/components/schemas/UsageInfo"
}
},
"additionalProperties": false,
@@ -6430,6 +6433,31 @@
"title": "TokenLogProbs",
"description": "Log probabilities for generated tokens."
},
+ "UsageInfo": {
+ "type": "object",
+ "properties": {
+ "completion_tokens": {
+ "type": "integer",
+ "description": "Number of tokens generated"
+ },
+ "prompt_tokens": {
+ "type": "integer",
+ "description": "Number of tokens in the prompt"
+ },
+ "total_tokens": {
+ "type": "integer",
+ "description": "Total number of tokens processed"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "completion_tokens",
+ "prompt_tokens",
+ "total_tokens"
+ ],
+ "title": "UsageInfo",
+        "description": "Token usage information for a completion response."
+ },
"BatchCompletionRequest": {
"type": "object",
"properties": {
@@ -10939,6 +10967,31 @@
"title": "OpenAIChatCompletionToolCallFunction",
"description": "Function call details for OpenAI-compatible tool calls."
},
+ "OpenAIChatCompletionUsage": {
+ "type": "object",
+ "properties": {
+ "prompt_tokens": {
+ "type": "integer",
+ "description": "The number of tokens in the prompt"
+ },
+ "completion_tokens": {
+ "type": "integer",
+ "description": "The number of tokens in the completion"
+ },
+ "total_tokens": {
+ "type": "integer",
+ "description": "The total number of tokens used"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "prompt_tokens",
+ "completion_tokens",
+ "total_tokens"
+ ],
+ "title": "OpenAIChatCompletionUsage",
+ "description": "Usage information for an OpenAI-compatible chat completion response."
+ },
"OpenAIChoice": {
"type": "object",
"properties": {
@@ -11276,6 +11329,13 @@
"OpenAICompletionWithInputMessages": {
"type": "object",
"properties": {
+ "metrics": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/MetricInResponse"
+ },
+ "description": "(Optional) List of metrics associated with the API response"
+ },
"id": {
"type": "string",
"description": "The ID of the chat completion"
@@ -11301,6 +11361,9 @@
"type": "string",
"description": "The model that was used to generate the chat completion"
},
+ "usage": {
+ "$ref": "#/components/schemas/OpenAIChatCompletionUsage"
+ },
"input_messages": {
"type": "array",
"items": {
@@ -13062,6 +13125,13 @@
"items": {
"type": "object",
"properties": {
+ "metrics": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/MetricInResponse"
+ },
+ "description": "(Optional) List of metrics associated with the API response"
+ },
"id": {
"type": "string",
"description": "The ID of the chat completion"
@@ -13087,6 +13157,9 @@
"type": "string",
"description": "The model that was used to generate the chat completion"
},
+ "usage": {
+ "$ref": "#/components/schemas/OpenAIChatCompletionUsage"
+ },
"input_messages": {
"type": "array",
"items": {
@@ -14478,6 +14551,13 @@
"OpenAIChatCompletion": {
"type": "object",
"properties": {
+ "metrics": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/MetricInResponse"
+ },
+ "description": "(Optional) List of metrics associated with the API response"
+ },
"id": {
"type": "string",
"description": "The ID of the chat completion"
@@ -14502,6 +14582,9 @@
"model": {
"type": "string",
"description": "The model that was used to generate the chat completion"
+ },
+ "usage": {
+ "$ref": "#/components/schemas/OpenAIChatCompletionUsage"
}
},
"additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 94dc5c0f9..d0096e268 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -4548,6 +4548,8 @@ components:
$ref: '#/components/schemas/TokenLogProbs'
description: >-
Optional log probabilities for generated tokens
+ usage:
+ $ref: '#/components/schemas/UsageInfo'
additionalProperties: false
required:
- completion_message
@@ -4589,6 +4591,25 @@ components:
- logprobs_by_token
title: TokenLogProbs
description: Log probabilities for generated tokens.
+ UsageInfo:
+ type: object
+ properties:
+ completion_tokens:
+ type: integer
+ description: Number of tokens generated
+ prompt_tokens:
+ type: integer
+ description: Number of tokens in the prompt
+ total_tokens:
+ type: integer
+ description: Total number of tokens processed
+ additionalProperties: false
+ required:
+ - completion_tokens
+ - prompt_tokens
+ - total_tokens
+ title: UsageInfo
+      description: >-
+        Token usage information for a completion response.
BatchCompletionRequest:
type: object
properties:
@@ -8103,6 +8124,26 @@ components:
title: OpenAIChatCompletionToolCallFunction
description: >-
Function call details for OpenAI-compatible tool calls.
+ OpenAIChatCompletionUsage:
+ type: object
+ properties:
+ prompt_tokens:
+ type: integer
+ description: The number of tokens in the prompt
+ completion_tokens:
+ type: integer
+ description: The number of tokens in the completion
+ total_tokens:
+ type: integer
+ description: The total number of tokens used
+ additionalProperties: false
+ required:
+ - prompt_tokens
+ - completion_tokens
+ - total_tokens
+ title: OpenAIChatCompletionUsage
+ description: >-
+ Usage information for an OpenAI-compatible chat completion response.
OpenAIChoice:
type: object
properties:
@@ -8365,6 +8406,12 @@ components:
OpenAICompletionWithInputMessages:
type: object
properties:
+ metrics:
+ type: array
+ items:
+ $ref: '#/components/schemas/MetricInResponse'
+ description: >-
+ (Optional) List of metrics associated with the API response
id:
type: string
description: The ID of the chat completion
@@ -8387,6 +8434,8 @@ components:
type: string
description: >-
The model that was used to generate the chat completion
+ usage:
+ $ref: '#/components/schemas/OpenAIChatCompletionUsage'
input_messages:
type: array
items:
@@ -9682,6 +9731,12 @@ components:
items:
type: object
properties:
+ metrics:
+ type: array
+ items:
+ $ref: '#/components/schemas/MetricInResponse'
+ description: >-
+ (Optional) List of metrics associated with the API response
id:
type: string
description: The ID of the chat completion
@@ -9704,6 +9759,8 @@ components:
type: string
description: >-
The model that was used to generate the chat completion
+ usage:
+ $ref: '#/components/schemas/OpenAIChatCompletionUsage'
input_messages:
type: array
items:
@@ -10719,6 +10776,12 @@ components:
OpenAIChatCompletion:
type: object
properties:
+ metrics:
+ type: array
+ items:
+ $ref: '#/components/schemas/MetricInResponse'
+ description: >-
+ (Optional) List of metrics associated with the API response
id:
type: string
description: The ID of the chat completion
@@ -10741,6 +10804,8 @@ components:
type: string
description: >-
The model that was used to generate the chat completion
+ usage:
+ $ref: '#/components/schemas/OpenAIChatCompletionUsage'
additionalProperties: false
required:
- id
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index bd4737ca7..1b7869a30 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -451,6 +451,20 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin):
event: ChatCompletionResponseEvent
+@json_schema_type
+class UsageInfo(BaseModel):
+    """Token usage information for a completion response.
+
+ :param completion_tokens: Number of tokens generated
+ :param prompt_tokens: Number of tokens in the prompt
+ :param total_tokens: Total number of tokens processed
+ """
+
+ completion_tokens: int
+ prompt_tokens: int
+ total_tokens: int
+
+
@json_schema_type
class ChatCompletionResponse(MetricResponseMixin):
"""Response from a chat completion request.
@@ -461,6 +475,7 @@ class ChatCompletionResponse(MetricResponseMixin):
completion_message: CompletionMessage
logprobs: list[TokenLogProbs] | None = None
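+    # Optional token accounting; populated only when the provider reports usage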
+ usage: UsageInfo | None = None
@json_schema_type
@@ -818,7 +833,21 @@ class OpenAIChoice(BaseModel):
@json_schema_type
-class OpenAIChatCompletion(BaseModel):
+class OpenAIChatCompletionUsage(BaseModel):
+ """Usage information for an OpenAI-compatible chat completion response.
+
+ :param prompt_tokens: The number of tokens in the prompt
+ :param completion_tokens: The number of tokens in the completion
+ :param total_tokens: The total number of tokens used
+ """
+
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+
+
+@json_schema_type
+class OpenAIChatCompletion(MetricResponseMixin):
"""Response from an OpenAI-compatible chat completion request.
:param id: The ID of the chat completion
@@ -833,6 +862,7 @@ class OpenAIChatCompletion(BaseModel):
object: Literal["chat.completion"] = "chat.completion"
created: int
model: str
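+    # May be None when the provider response does not include token counts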
+ usage: OpenAIChatCompletionUsage | None = None
@json_schema_type
diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py
index 762d7073e..1c356d1f1 100644
--- a/llama_stack/core/routers/inference.py
+++ b/llama_stack/core/routers/inference.py
@@ -590,6 +590,7 @@ class InferenceRouter(Inference):
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(**params)
+
for choice in response.choices:
# some providers return an empty list for no tool calls in non-streaming responses
# but the OpenAI API returns None. So, set tool_calls to None if it's empty
@@ -739,7 +740,6 @@ class InferenceRouter(Inference):
id = None
created = None
choices_data: dict[int, dict[str, Any]] = {}
-
try:
async for chunk in response:
# Skip None chunks
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 2c01d192c..fc77a7214 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -130,7 +130,7 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
- stream = self.client.completions.create(**params)
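+        # await the async client call so we pass an actual stream, not a coroutine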
+ stream = await self.client.completions.create(**params)
async for chunk in process_completion_stream_response(stream):
yield chunk
@@ -208,9 +208,9 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
params = await self._get_params(request)
if "messages" in params:
- stream = self.client.chat.completions.create(**params)
+ stream = await self.client.chat.completions.create(**params)
else:
- stream = self.client.completions.create(**params)
+ stream = await self.client.completions.create(**params)
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 55c2ac0ad..3ef4fb134 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -31,6 +31,8 @@ from openai.types.chat import (
ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
)
+from llama_stack.apis.inference.inference import UsageInfo
+
try:
from openai.types.chat import (
ChatCompletionMessageFunctionToolCall as OpenAIChatCompletionMessageFunctionToolCall,
@@ -103,6 +105,7 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
Message,
OpenAIChatCompletion,
+ OpenAIChatCompletionUsage,
OpenAICompletion,
OpenAICompletionChoice,
OpenAIEmbeddingData,
@@ -277,6 +280,11 @@ def process_chat_completion_response(
request: ChatCompletionRequest,
) -> ChatCompletionResponse:
choice = response.choices[0]
+    # Guard against providers that do not report token usage on the response
+    response_usage = getattr(response, "usage", None)
+    usage = None
+    if response_usage:
+        usage = UsageInfo(
+            prompt_tokens=response_usage.prompt_tokens,
+            completion_tokens=response_usage.completion_tokens,
+            total_tokens=response_usage.total_tokens,
+        )
if choice.finish_reason == "tool_calls":
if not choice.message or not choice.message.tool_calls:
raise ValueError("Tool calls are not present in the response")
@@ -290,6 +298,7 @@ def process_chat_completion_response(
content=json.dumps(tool_calls, default=lambda x: x.model_dump()),
),
logprobs=None,
+ usage=usage,
)
else:
# Otherwise, return tool calls as normal
@@ -301,6 +310,7 @@ def process_chat_completion_response(
content="",
),
logprobs=None,
+ usage=usage,
)
# TODO: This does not work well with tool calls for vLLM remote provider
@@ -335,6 +345,7 @@ def process_chat_completion_response(
tool_calls=raw_message.tool_calls,
),
logprobs=None,
+ usage=usage,
)
@@ -646,7 +657,7 @@ async def convert_message_to_openai_dict_new(
arguments=json.dumps(tool.arguments),
),
type="function",
- )
+ ).model_dump()
for tool in message.tool_calls
]
params = {}
@@ -657,6 +668,7 @@ async def convert_message_to_openai_dict_new(
content=await _convert_message_content(message.content),
**params,
)
+
elif isinstance(message, ToolResponseMessage):
out = OpenAIChatCompletionToolMessage(
role="tool",
@@ -1375,6 +1387,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
messages = openai_messages_to_messages(messages)
+
response_format = _convert_openai_request_response_format(response_format)
sampling_params = _convert_openai_sampling_params(
max_tokens=max_tokens,
@@ -1401,7 +1414,6 @@ class OpenAIChatCompletionToLlamaStackMixin:
tools=tools,
)
outstanding_responses.append(response)
-
if stream:
return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
@@ -1476,12 +1488,22 @@ class OpenAIChatCompletionToLlamaStackMixin:
self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
) -> OpenAIChatCompletion:
choices = []
+ total_prompt_tokens = 0
+ total_completion_tokens = 0
+ total_tokens = 0
+
for outstanding_response in outstanding_responses:
response = await outstanding_response
completion_message = response.completion_message
message = await convert_message_to_openai_dict_new(completion_message)
finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason)
+ # Aggregate usage data
+ if response.usage:
+ total_prompt_tokens += response.usage.prompt_tokens
+ total_completion_tokens += response.usage.completion_tokens
+ total_tokens += response.usage.total_tokens
+
choice = OpenAIChatCompletionChoice(
index=len(choices),
message=message,
@@ -1489,12 +1511,17 @@ class OpenAIChatCompletionToLlamaStackMixin:
)
choices.append(choice)
+ usage = OpenAIChatCompletionUsage(
+ prompt_tokens=total_prompt_tokens, completion_tokens=total_completion_tokens, total_tokens=total_tokens
+ )
+
return OpenAIChatCompletion(
id=f"chatcmpl-{uuid.uuid4()}",
choices=choices,
created=int(time.time()),
model=model,
object="chat.completion",
+ usage=usage,
)
diff --git a/tests/integration/suites.py b/tests/integration/suites.py
index 231480447..e8b1b6973 100644
--- a/tests/integration/suites.py
+++ b/tests/integration/suites.py
@@ -108,6 +108,15 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
"embedding_model": "together/togethercomputer/m2-bert-80M-32k-retrieval",
},
),
+ "fireworks": Setup(
+ name="fireworks",
+ description="Fireworks provider with a text model",
+ defaults={
+ "text_model": "fireworks/accounts/fireworks/models/llama-v3p1-8b-instruct",
+ "vision_model": "fireworks/accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
+ "embedding_model": "nomic-ai/nomic-embed-text-v1.5",
+ },
+ ),
}