Merge branch 'main' into feat/litellm_sambanova_usage

Commit 488eb8f249 by Jorge Piedrahita Ortiz, 2025-04-14 12:15:44 -05:00 (committed by GitHub)
39 changed files with 2102 additions and 164 deletions


@@ -3096,11 +3096,18 @@
"post": {
"responses": {
"200": {
- "description": "OK",
"description": "Response from an OpenAI-compatible chat completion request. **OR** Chunk from a streaming response to an OpenAI-compatible chat completion request.",
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/OpenAIChatCompletion"
"oneOf": [
{
"$ref": "#/components/schemas/OpenAIChatCompletion"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionChunk"
}
]
}
}
}
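With the response schema now a oneOf over OpenAIChatCompletion and OpenAIChatCompletionChunk, callers of the OpenAI-compatible endpoint have to branch on whether they requested streaming. A minimal client-side sketch, assuming a hypothetical `client` object that exposes the `openai_chat_completion` method defined by the Inference protocol later in this diff:

```python
# Hedged sketch: `client` is a placeholder for anything implementing the
# Inference protocol's openai_chat_completion method.
async def chat(client, model: str, messages: list, stream: bool) -> str:
    result = await client.openai_chat_completion(model=model, messages=messages, stream=stream)
    if not stream:
        # Non-streaming: a single OpenAIChatCompletion (assumes plain string content).
        return result.choices[0].message.content
    # Streaming: an async iterator of OpenAIChatCompletionChunk objects.
    pieces = []
    async for chunk in result:
        delta = chunk.choices[0].delta
        if delta.content:
            pieces.append(delta.content)
    return "".join(pieces)
```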
@@ -8857,7 +8864,17 @@
"description": "Must be \"assistant\" to identify this as the model's response"
},
"content": {
- "$ref": "#/components/schemas/InterleavedContent",
"oneOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
}
}
],
"description": "The content of the model's response"
},
"name": {
@@ -8867,9 +8884,9 @@
"tool_calls": {
"type": "array",
"items": {
- "$ref": "#/components/schemas/ToolCall"
"$ref": "#/components/schemas/OpenAIChatCompletionToolCall"
},
- "description": "List of tool calls. Each tool call is a ToolCall object."
"description": "List of tool calls. Each tool call is an OpenAIChatCompletionToolCall object."
}
},
"additionalProperties": false,
@@ -8880,6 +8897,98 @@
"title": "OpenAIAssistantMessageParam",
"description": "A message containing the model's (assistant) response in an OpenAI-compatible chat completion request."
},
"OpenAIChatCompletionContentPartImageParam": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "image_url",
"default": "image_url"
},
"image_url": {
"$ref": "#/components/schemas/OpenAIImageURL"
}
},
"additionalProperties": false,
"required": [
"type",
"image_url"
],
"title": "OpenAIChatCompletionContentPartImageParam"
},
"OpenAIChatCompletionContentPartParam": {
"oneOf": [
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"text": "#/components/schemas/OpenAIChatCompletionContentPartTextParam",
"image_url": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
}
}
},
"OpenAIChatCompletionContentPartTextParam": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "text",
"default": "text"
},
"text": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"type",
"text"
],
"title": "OpenAIChatCompletionContentPartTextParam"
},
"OpenAIChatCompletionToolCall": {
"type": "object",
"properties": {
"index": {
"type": "integer"
},
"id": {
"type": "string"
},
"type": {
"type": "string",
"const": "function",
"default": "function"
},
"function": {
"$ref": "#/components/schemas/OpenAIChatCompletionToolCallFunction"
}
},
"additionalProperties": false,
"required": [
"type"
],
"title": "OpenAIChatCompletionToolCall"
},
"OpenAIChatCompletionToolCallFunction": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"arguments": {
"type": "string"
}
},
"additionalProperties": false,
"title": "OpenAIChatCompletionToolCallFunction"
},
"OpenAIDeveloperMessageParam": { "OpenAIDeveloperMessageParam": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -8890,7 +8999,17 @@
"description": "Must be \"developer\" to identify this as a developer message" "description": "Must be \"developer\" to identify this as a developer message"
}, },
"content": { "content": {
"$ref": "#/components/schemas/InterleavedContent", "oneOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
}
}
],
"description": "The content of the developer message" "description": "The content of the developer message"
}, },
"name": { "name": {
@ -8906,6 +9025,66 @@
"title": "OpenAIDeveloperMessageParam", "title": "OpenAIDeveloperMessageParam",
"description": "A message from the developer in an OpenAI-compatible chat completion request." "description": "A message from the developer in an OpenAI-compatible chat completion request."
}, },
"OpenAIImageURL": {
"type": "object",
"properties": {
"url": {
"type": "string"
},
"detail": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"url"
],
"title": "OpenAIImageURL"
},
"OpenAIJSONSchema": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"description": {
"type": "string"
},
"strict": {
"type": "boolean"
},
"schema": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"name"
],
"title": "OpenAIJSONSchema"
},
"OpenAIMessageParam": { "OpenAIMessageParam": {
"oneOf": [ "oneOf": [
{ {
@ -8935,6 +9114,76 @@
} }
} }
}, },
"OpenAIResponseFormatJSONObject": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "json_object",
"default": "json_object"
}
},
"additionalProperties": false,
"required": [
"type"
],
"title": "OpenAIResponseFormatJSONObject"
},
"OpenAIResponseFormatJSONSchema": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "json_schema",
"default": "json_schema"
},
"json_schema": {
"$ref": "#/components/schemas/OpenAIJSONSchema"
}
},
"additionalProperties": false,
"required": [
"type",
"json_schema"
],
"title": "OpenAIResponseFormatJSONSchema"
},
"OpenAIResponseFormatParam": {
"oneOf": [
{
"$ref": "#/components/schemas/OpenAIResponseFormatText"
},
{
"$ref": "#/components/schemas/OpenAIResponseFormatJSONSchema"
},
{
"$ref": "#/components/schemas/OpenAIResponseFormatJSONObject"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"text": "#/components/schemas/OpenAIResponseFormatText",
"json_schema": "#/components/schemas/OpenAIResponseFormatJSONSchema",
"json_object": "#/components/schemas/OpenAIResponseFormatJSONObject"
}
}
},
"OpenAIResponseFormatText": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "text",
"default": "text"
}
},
"additionalProperties": false,
"required": [
"type"
],
"title": "OpenAIResponseFormatText"
},
"OpenAISystemMessageParam": { "OpenAISystemMessageParam": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -8945,7 +9194,17 @@
"description": "Must be \"system\" to identify this as a system message" "description": "Must be \"system\" to identify this as a system message"
}, },
"content": { "content": {
"$ref": "#/components/schemas/InterleavedContent", "oneOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
}
}
],
"description": "The content of the \"system prompt\". If multiple system messages are provided, they are concatenated. The underlying Llama Stack code may also add other system messages (for example, for formatting tool definitions)." "description": "The content of the \"system prompt\". If multiple system messages are provided, they are concatenated. The underlying Llama Stack code may also add other system messages (for example, for formatting tool definitions)."
}, },
"name": { "name": {
@ -8975,7 +9234,17 @@
"description": "Unique identifier for the tool call this response is for" "description": "Unique identifier for the tool call this response is for"
}, },
"content": { "content": {
"$ref": "#/components/schemas/InterleavedContent", "oneOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
}
}
],
"description": "The response content from the tool" "description": "The response content from the tool"
} }
}, },
@ -8998,7 +9267,17 @@
"description": "Must be \"user\" to identify this as a user message" "description": "Must be \"user\" to identify this as a user message"
}, },
"content": { "content": {
"$ref": "#/components/schemas/InterleavedContent", "oneOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
}
}
],
"description": "The content of the message, which can include text and other media" "description": "The content of the message, which can include text and other media"
}, },
"name": { "name": {
@@ -9126,10 +9405,7 @@
"description": "(Optional) The penalty for repeated tokens"
},
"response_format": {
- "type": "object",
- "additionalProperties": {
- "type": "string"
- },
"$ref": "#/components/schemas/OpenAIResponseFormatParam",
"description": "(Optional) The response format to use"
},
"seed": {
@ -9306,6 +9582,46 @@
"title": "OpenAIChatCompletion", "title": "OpenAIChatCompletion",
"description": "Response from an OpenAI-compatible chat completion request." "description": "Response from an OpenAI-compatible chat completion request."
}, },
"OpenAIChatCompletionChunk": {
"type": "object",
"properties": {
"id": {
"type": "string",
"description": "The ID of the chat completion"
},
"choices": {
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIChunkChoice"
},
"description": "List of choices"
},
"object": {
"type": "string",
"const": "chat.completion.chunk",
"default": "chat.completion.chunk",
"description": "The object type, which will be \"chat.completion.chunk\""
},
"created": {
"type": "integer",
"description": "The Unix timestamp in seconds when the chat completion was created"
},
"model": {
"type": "string",
"description": "The model that was used to generate the chat completion"
}
},
"additionalProperties": false,
"required": [
"id",
"choices",
"object",
"created",
"model"
],
"title": "OpenAIChatCompletionChunk",
"description": "Chunk from a streaming response to an OpenAI-compatible chat completion request."
},
"OpenAIChoice": { "OpenAIChoice": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -9318,10 +9634,12 @@
"description": "The reason the model stopped generating" "description": "The reason the model stopped generating"
}, },
"index": { "index": {
"type": "integer" "type": "integer",
"description": "The index of the choice"
}, },
"logprobs": { "logprobs": {
"$ref": "#/components/schemas/OpenAIChoiceLogprobs" "$ref": "#/components/schemas/OpenAIChoiceLogprobs",
"description": "(Optional) The log probabilities for the tokens in the message"
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -9333,6 +9651,33 @@
"title": "OpenAIChoice", "title": "OpenAIChoice",
"description": "A choice from an OpenAI-compatible chat completion response." "description": "A choice from an OpenAI-compatible chat completion response."
}, },
"OpenAIChoiceDelta": {
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "(Optional) The content of the delta"
},
"refusal": {
"type": "string",
"description": "(Optional) The refusal of the delta"
},
"role": {
"type": "string",
"description": "(Optional) The role of the delta"
},
"tool_calls": {
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIChatCompletionToolCall"
},
"description": "(Optional) The tool calls of the delta"
}
},
"additionalProperties": false,
"title": "OpenAIChoiceDelta",
"description": "A delta from an OpenAI-compatible chat completion streaming response."
},
"OpenAIChoiceLogprobs": { "OpenAIChoiceLogprobs": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -9340,19 +9685,50 @@
"type": "array", "type": "array",
"items": { "items": {
"$ref": "#/components/schemas/OpenAITokenLogProb" "$ref": "#/components/schemas/OpenAITokenLogProb"
} },
"description": "(Optional) The log probabilities for the tokens in the message"
}, },
"refusal": { "refusal": {
"type": "array", "type": "array",
"items": { "items": {
"$ref": "#/components/schemas/OpenAITokenLogProb" "$ref": "#/components/schemas/OpenAITokenLogProb"
} },
"description": "(Optional) The log probabilities for the tokens in the message"
} }
}, },
"additionalProperties": false, "additionalProperties": false,
"title": "OpenAIChoiceLogprobs", "title": "OpenAIChoiceLogprobs",
"description": "The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response." "description": "The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response."
}, },
"OpenAIChunkChoice": {
"type": "object",
"properties": {
"delta": {
"$ref": "#/components/schemas/OpenAIChoiceDelta",
"description": "The delta from the chunk"
},
"finish_reason": {
"type": "string",
"description": "The reason the model stopped generating"
},
"index": {
"type": "integer",
"description": "The index of the choice"
},
"logprobs": {
"$ref": "#/components/schemas/OpenAIChoiceLogprobs",
"description": "(Optional) The log probabilities for the tokens in the message"
}
},
"additionalProperties": false,
"required": [
"delta",
"finish_reason",
"index"
],
"title": "OpenAIChunkChoice",
"description": "A chunk choice from an OpenAI-compatible chat completion streaming response."
},
"OpenAITokenLogProb": { "OpenAITokenLogProb": {
"type": "object", "type": "object",
"properties": { "properties": {


@ -2135,11 +2135,15 @@ paths:
post: post:
responses: responses:
'200': '200':
description: OK description: >-
Response from an OpenAI-compatible chat completion request. **OR** Chunk
from a streaming response to an OpenAI-compatible chat completion request.
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/OpenAIChatCompletion' oneOf:
- $ref: '#/components/schemas/OpenAIChatCompletion'
- $ref: '#/components/schemas/OpenAIChatCompletionChunk'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -6073,7 +6077,11 @@ components:
description: >- description: >-
Must be "assistant" to identify this as the model's response Must be "assistant" to identify this as the model's response
content: content:
$ref: '#/components/schemas/InterleavedContent' oneOf:
- type: string
- type: array
items:
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
description: The content of the model's response description: The content of the model's response
name: name:
type: string type: string
@ -6082,9 +6090,10 @@ components:
tool_calls: tool_calls:
type: array type: array
items: items:
$ref: '#/components/schemas/ToolCall' $ref: '#/components/schemas/OpenAIChatCompletionToolCall'
description: >- description: >-
List of tool calls. Each tool call is a ToolCall object. List of tool calls. Each tool call is an OpenAIChatCompletionToolCall
object.
additionalProperties: false additionalProperties: false
required: required:
- role - role
@ -6093,6 +6102,70 @@ components:
description: >- description: >-
A message containing the model's (assistant) response in an OpenAI-compatible A message containing the model's (assistant) response in an OpenAI-compatible
chat completion request. chat completion request.
"OpenAIChatCompletionContentPartImageParam":
type: object
properties:
type:
type: string
const: image_url
default: image_url
image_url:
$ref: '#/components/schemas/OpenAIImageURL'
additionalProperties: false
required:
- type
- image_url
title: >-
OpenAIChatCompletionContentPartImageParam
OpenAIChatCompletionContentPartParam:
oneOf:
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
discriminator:
propertyName: type
mapping:
text: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
image_url: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
OpenAIChatCompletionContentPartTextParam:
type: object
properties:
type:
type: string
const: text
default: text
text:
type: string
additionalProperties: false
required:
- type
- text
title: OpenAIChatCompletionContentPartTextParam
OpenAIChatCompletionToolCall:
type: object
properties:
index:
type: integer
id:
type: string
type:
type: string
const: function
default: function
function:
$ref: '#/components/schemas/OpenAIChatCompletionToolCallFunction'
additionalProperties: false
required:
- type
title: OpenAIChatCompletionToolCall
OpenAIChatCompletionToolCallFunction:
type: object
properties:
name:
type: string
arguments:
type: string
additionalProperties: false
title: OpenAIChatCompletionToolCallFunction
OpenAIDeveloperMessageParam: OpenAIDeveloperMessageParam:
type: object type: object
properties: properties:
@ -6103,7 +6176,11 @@ components:
description: >- description: >-
Must be "developer" to identify this as a developer message Must be "developer" to identify this as a developer message
content: content:
$ref: '#/components/schemas/InterleavedContent' oneOf:
- type: string
- type: array
items:
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
description: The content of the developer message description: The content of the developer message
name: name:
type: string type: string
@ -6116,6 +6193,40 @@ components:
title: OpenAIDeveloperMessageParam title: OpenAIDeveloperMessageParam
description: >- description: >-
A message from the developer in an OpenAI-compatible chat completion request. A message from the developer in an OpenAI-compatible chat completion request.
OpenAIImageURL:
type: object
properties:
url:
type: string
detail:
type: string
additionalProperties: false
required:
- url
title: OpenAIImageURL
OpenAIJSONSchema:
type: object
properties:
name:
type: string
description:
type: string
strict:
type: boolean
schema:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
additionalProperties: false
required:
- name
title: OpenAIJSONSchema
OpenAIMessageParam: OpenAIMessageParam:
oneOf: oneOf:
- $ref: '#/components/schemas/OpenAIUserMessageParam' - $ref: '#/components/schemas/OpenAIUserMessageParam'
@ -6131,6 +6242,53 @@ components:
assistant: '#/components/schemas/OpenAIAssistantMessageParam' assistant: '#/components/schemas/OpenAIAssistantMessageParam'
tool: '#/components/schemas/OpenAIToolMessageParam' tool: '#/components/schemas/OpenAIToolMessageParam'
developer: '#/components/schemas/OpenAIDeveloperMessageParam' developer: '#/components/schemas/OpenAIDeveloperMessageParam'
OpenAIResponseFormatJSONObject:
type: object
properties:
type:
type: string
const: json_object
default: json_object
additionalProperties: false
required:
- type
title: OpenAIResponseFormatJSONObject
OpenAIResponseFormatJSONSchema:
type: object
properties:
type:
type: string
const: json_schema
default: json_schema
json_schema:
$ref: '#/components/schemas/OpenAIJSONSchema'
additionalProperties: false
required:
- type
- json_schema
title: OpenAIResponseFormatJSONSchema
OpenAIResponseFormatParam:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseFormatText'
- $ref: '#/components/schemas/OpenAIResponseFormatJSONSchema'
- $ref: '#/components/schemas/OpenAIResponseFormatJSONObject'
discriminator:
propertyName: type
mapping:
text: '#/components/schemas/OpenAIResponseFormatText'
json_schema: '#/components/schemas/OpenAIResponseFormatJSONSchema'
json_object: '#/components/schemas/OpenAIResponseFormatJSONObject'
OpenAIResponseFormatText:
type: object
properties:
type:
type: string
const: text
default: text
additionalProperties: false
required:
- type
title: OpenAIResponseFormatText
OpenAISystemMessageParam: OpenAISystemMessageParam:
type: object type: object
properties: properties:
@ -6141,7 +6299,11 @@ components:
description: >- description: >-
Must be "system" to identify this as a system message Must be "system" to identify this as a system message
content: content:
$ref: '#/components/schemas/InterleavedContent' oneOf:
- type: string
- type: array
items:
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
description: >- description: >-
The content of the "system prompt". If multiple system messages are provided, The content of the "system prompt". If multiple system messages are provided,
they are concatenated. The underlying Llama Stack code may also add other they are concatenated. The underlying Llama Stack code may also add other
@ -6171,7 +6333,11 @@ components:
description: >- description: >-
Unique identifier for the tool call this response is for Unique identifier for the tool call this response is for
content: content:
$ref: '#/components/schemas/InterleavedContent' oneOf:
- type: string
- type: array
items:
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
description: The response content from the tool description: The response content from the tool
additionalProperties: false additionalProperties: false
required: required:
@ -6192,7 +6358,11 @@ components:
description: >- description: >-
Must be "user" to identify this as a user message Must be "user" to identify this as a user message
content: content:
$ref: '#/components/schemas/InterleavedContent' oneOf:
- type: string
- type: array
items:
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
description: >- description: >-
The content of the message, which can include text and other media The content of the message, which can include text and other media
name: name:
@ -6278,9 +6448,7 @@ components:
description: >- description: >-
(Optional) The penalty for repeated tokens (Optional) The penalty for repeated tokens
response_format: response_format:
type: object $ref: '#/components/schemas/OpenAIResponseFormatParam'
additionalProperties:
type: string
description: (Optional) The response format to use description: (Optional) The response format to use
seed: seed:
type: integer type: integer
@ -6386,6 +6554,41 @@ components:
title: OpenAIChatCompletion title: OpenAIChatCompletion
description: >- description: >-
Response from an OpenAI-compatible chat completion request. Response from an OpenAI-compatible chat completion request.
OpenAIChatCompletionChunk:
type: object
properties:
id:
type: string
description: The ID of the chat completion
choices:
type: array
items:
$ref: '#/components/schemas/OpenAIChunkChoice'
description: List of choices
object:
type: string
const: chat.completion.chunk
default: chat.completion.chunk
description: >-
The object type, which will be "chat.completion.chunk"
created:
type: integer
description: >-
The Unix timestamp in seconds when the chat completion was created
model:
type: string
description: >-
The model that was used to generate the chat completion
additionalProperties: false
required:
- id
- choices
- object
- created
- model
title: OpenAIChatCompletionChunk
description: >-
Chunk from a streaming response to an OpenAI-compatible chat completion request.
OpenAIChoice: OpenAIChoice:
type: object type: object
properties: properties:
@ -6397,8 +6600,11 @@ components:
description: The reason the model stopped generating description: The reason the model stopped generating
index: index:
type: integer type: integer
description: The index of the choice
logprobs: logprobs:
$ref: '#/components/schemas/OpenAIChoiceLogprobs' $ref: '#/components/schemas/OpenAIChoiceLogprobs'
description: >-
(Optional) The log probabilities for the tokens in the message
additionalProperties: false additionalProperties: false
required: required:
- message - message
@ -6407,6 +6613,27 @@ components:
title: OpenAIChoice title: OpenAIChoice
description: >- description: >-
A choice from an OpenAI-compatible chat completion response. A choice from an OpenAI-compatible chat completion response.
OpenAIChoiceDelta:
type: object
properties:
content:
type: string
description: (Optional) The content of the delta
refusal:
type: string
description: (Optional) The refusal of the delta
role:
type: string
description: (Optional) The role of the delta
tool_calls:
type: array
items:
$ref: '#/components/schemas/OpenAIChatCompletionToolCall'
description: (Optional) The tool calls of the delta
additionalProperties: false
title: OpenAIChoiceDelta
description: >-
A delta from an OpenAI-compatible chat completion streaming response.
OpenAIChoiceLogprobs: OpenAIChoiceLogprobs:
type: object type: object
properties: properties:
@ -6414,15 +6641,43 @@ components:
type: array type: array
items: items:
$ref: '#/components/schemas/OpenAITokenLogProb' $ref: '#/components/schemas/OpenAITokenLogProb'
description: >-
(Optional) The log probabilities for the tokens in the message
refusal: refusal:
type: array type: array
items: items:
$ref: '#/components/schemas/OpenAITokenLogProb' $ref: '#/components/schemas/OpenAITokenLogProb'
description: >-
(Optional) The log probabilities for the tokens in the message
additionalProperties: false additionalProperties: false
title: OpenAIChoiceLogprobs title: OpenAIChoiceLogprobs
description: >- description: >-
The log probabilities for the tokens in the message from an OpenAI-compatible The log probabilities for the tokens in the message from an OpenAI-compatible
chat completion response. chat completion response.
OpenAIChunkChoice:
type: object
properties:
delta:
$ref: '#/components/schemas/OpenAIChoiceDelta'
description: The delta from the chunk
finish_reason:
type: string
description: The reason the model stopped generating
index:
type: integer
description: The index of the choice
logprobs:
$ref: '#/components/schemas/OpenAIChoiceLogprobs'
description: >-
(Optional) The log probabilities for the tokens in the message
additionalProperties: false
required:
- delta
- finish_reason
- index
title: OpenAIChunkChoice
description: >-
A chunk choice from an OpenAI-compatible chat completion streaming response.
OpenAITokenLogProb: OpenAITokenLogProb:
type: object type: object
properties: properties:
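The response_format request parameter is now the discriminated OpenAIResponseFormatParam union rather than a free-form string map. The three accepted shapes look roughly like this (schema contents are illustrative):

```python
# Illustrative response_format payloads for the three union members.
format_text = {"type": "text"}
format_json_object = {"type": "json_object"}
format_json_schema = {
    "type": "json_schema",
    "json_schema": {
        "name": "weather_report",  # made-up schema name
        "strict": True,
        "schema": {
            "type": "object",
            "properties": {"temperature_c": {"type": "number"}},
            "required": ["temperature_c"],
        },
    },
}
```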


@@ -43,7 +43,9 @@ The following models are available by default:
- `groq/llama-3.3-70b-versatile (aliases: meta-llama/Llama-3.3-70B-Instruct)`
- `groq/llama-3.2-3b-preview (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `groq/llama-4-scout-17b-16e-instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
- `groq/meta-llama/llama-4-scout-17b-16e-instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
- `groq/llama-4-maverick-17b-128e-instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
- `groq/meta-llama/llama-4-maverick-17b-128e-instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
### Prerequisite: API Keys


@@ -18,7 +18,7 @@ from typing import (
)
from pydantic import BaseModel, Field, field_validator
- from typing_extensions import Annotated
from typing_extensions import Annotated, TypedDict
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
from llama_stack.apis.models import Model
@@ -442,6 +442,37 @@ class EmbeddingsResponse(BaseModel):
embeddings: List[List[float]]
@json_schema_type
class OpenAIChatCompletionContentPartTextParam(BaseModel):
type: Literal["text"] = "text"
text: str
@json_schema_type
class OpenAIImageURL(BaseModel):
url: str
detail: Optional[str] = None
@json_schema_type
class OpenAIChatCompletionContentPartImageParam(BaseModel):
type: Literal["image_url"] = "image_url"
image_url: OpenAIImageURL
OpenAIChatCompletionContentPartParam = Annotated[
Union[
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionContentPartImageParam,
],
Field(discriminator="type"),
]
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
OpenAIChatCompletionMessageContent = Union[str, List[OpenAIChatCompletionContentPartParam]]
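A short sketch of how these content-part models compose into message content; the URL is a placeholder:

```python
# Sketch: mixed text/image content built from the new content-part models.
rich_content = [
    OpenAIChatCompletionContentPartTextParam(text="Describe this image."),
    OpenAIChatCompletionContentPartImageParam(
        image_url=OpenAIImageURL(url="https://example.com/photo.jpg", detail="high")
    ),
]
# Plain strings remain valid via the OpenAIChatCompletionMessageContent union.
simple_content = "Hello!"
```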
@json_schema_type
class OpenAIUserMessageParam(BaseModel):
"""A message from the user in an OpenAI-compatible chat completion request.
@@ -452,7 +483,7 @@ class OpenAIUserMessageParam(BaseModel):
"""
role: Literal["user"] = "user"
- content: InterleavedContent
content: OpenAIChatCompletionMessageContent
name: Optional[str] = None
@@ -466,10 +497,24 @@ class OpenAISystemMessageParam(BaseModel):
"""
role: Literal["system"] = "system"
- content: InterleavedContent
content: OpenAIChatCompletionMessageContent
name: Optional[str] = None
@json_schema_type
class OpenAIChatCompletionToolCallFunction(BaseModel):
name: Optional[str] = None
arguments: Optional[str] = None
@json_schema_type
class OpenAIChatCompletionToolCall(BaseModel):
index: Optional[int] = None
id: Optional[str] = None
type: Literal["function"] = "function"
function: Optional[OpenAIChatCompletionToolCallFunction] = None
@json_schema_type
class OpenAIAssistantMessageParam(BaseModel):
"""A message containing the model's (assistant) response in an OpenAI-compatible chat completion request.
@@ -477,13 +522,13 @@ class OpenAIAssistantMessageParam(BaseModel):
:param role: Must be "assistant" to identify this as the model's response
:param content: The content of the model's response
:param name: (Optional) The name of the assistant message participant.
- :param tool_calls: List of tool calls. Each tool call is a ToolCall object.
:param tool_calls: List of tool calls. Each tool call is an OpenAIChatCompletionToolCall object.
"""
role: Literal["assistant"] = "assistant"
- content: InterleavedContent
content: OpenAIChatCompletionMessageContent
name: Optional[str] = None
- tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = Field(default_factory=list)
@json_schema_type
@@ -497,7 +542,7 @@ class OpenAIToolMessageParam(BaseModel):
role: Literal["tool"] = "tool"
tool_call_id: str
- content: InterleavedContent
content: OpenAIChatCompletionMessageContent
@json_schema_type
@@ -510,7 +555,7 @@ class OpenAIDeveloperMessageParam(BaseModel):
"""
role: Literal["developer"] = "developer"
- content: InterleavedContent
content: OpenAIChatCompletionMessageContent
name: Optional[str] = None
@@ -527,6 +572,46 @@ OpenAIMessageParam = Annotated[
register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
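For illustration, an assistant turn that requests a tool invocation could be assembled from these models as follows; the tool name and arguments are made up:

```python
import json

# Sketch: an assistant message carrying a structured tool call.
tool_call = OpenAIChatCompletionToolCall(
    index=0,
    id="call_0",
    function=OpenAIChatCompletionToolCallFunction(
        name="get_weather",                        # hypothetical tool name
        arguments=json.dumps({"city": "Bogota"}),  # arguments travel as a JSON-encoded string
    ),
)
assistant_turn = OpenAIAssistantMessageParam(
    content="Checking the weather now.",
    tool_calls=[tool_call],
)
```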
@json_schema_type
class OpenAIResponseFormatText(BaseModel):
type: Literal["text"] = "text"
@json_schema_type
class OpenAIJSONSchema(TypedDict, total=False):
name: str
description: Optional[str] = None
strict: Optional[bool] = None
# Pydantic BaseModel cannot be used with a schema param, since it already
# has one. And, we don't want to alias here because then have to handle
# that alias when converting to OpenAI params. So, to support schema,
# we use a TypedDict.
schema: Optional[Dict[str, Any]] = None
@json_schema_type
class OpenAIResponseFormatJSONSchema(BaseModel):
type: Literal["json_schema"] = "json_schema"
json_schema: OpenAIJSONSchema
@json_schema_type
class OpenAIResponseFormatJSONObject(BaseModel):
type: Literal["json_object"] = "json_object"
OpenAIResponseFormatParam = Annotated[
Union[
OpenAIResponseFormatText,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatJSONObject,
],
Field(discriminator="type"),
]
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
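Because OpenAIJSONSchema is a TypedDict (so the field can literally be called schema), a structured-output request parameter might be built like this; the schema itself is illustrative:

```python
# Sketch: a json_schema response format assembled from the models above.
response_format = OpenAIResponseFormatJSONSchema(
    json_schema=OpenAIJSONSchema(
        name="extracted_person",  # made-up name
        strict=True,
        schema={
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
            "required": ["name"],
        },
    )
)
```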
@json_schema_type
class OpenAITopLogProb(BaseModel):
"""The top log probability for a token from an OpenAI-compatible chat completion response.
@@ -561,22 +646,54 @@ class OpenAITokenLogProb(BaseModel):
class OpenAIChoiceLogprobs(BaseModel):
"""The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response.
- :content: (Optional) The log probabilities for the tokens in the message
- :refusal: (Optional) The log probabilities for the tokens in the message
:param content: (Optional) The log probabilities for the tokens in the message
:param refusal: (Optional) The log probabilities for the tokens in the message
"""
content: Optional[List[OpenAITokenLogProb]] = None
refusal: Optional[List[OpenAITokenLogProb]] = None
@json_schema_type
class OpenAIChoiceDelta(BaseModel):
"""A delta from an OpenAI-compatible chat completion streaming response.
:param content: (Optional) The content of the delta
:param refusal: (Optional) The refusal of the delta
:param role: (Optional) The role of the delta
:param tool_calls: (Optional) The tool calls of the delta
"""
content: Optional[str] = None
refusal: Optional[str] = None
role: Optional[str] = None
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
@json_schema_type
class OpenAIChunkChoice(BaseModel):
"""A chunk choice from an OpenAI-compatible chat completion streaming response.
:param delta: The delta from the chunk
:param finish_reason: The reason the model stopped generating
:param index: The index of the choice
:param logprobs: (Optional) The log probabilities for the tokens in the message
"""
delta: OpenAIChoiceDelta
finish_reason: str
index: int
logprobs: Optional[OpenAIChoiceLogprobs] = None
@json_schema_type
class OpenAIChoice(BaseModel):
"""A choice from an OpenAI-compatible chat completion response.
:param message: The message from the model
:param finish_reason: The reason the model stopped generating
- :index: The index of the choice
- :logprobs: (Optional) The log probabilities for the tokens in the message
:param index: The index of the choice
:param logprobs: (Optional) The log probabilities for the tokens in the message
"""
message: OpenAIMessageParam
@@ -603,6 +720,24 @@ class OpenAIChatCompletion(BaseModel):
model: str
@json_schema_type
class OpenAIChatCompletionChunk(BaseModel):
"""Chunk from a streaming response to an OpenAI-compatible chat completion request.
:param id: The ID of the chat completion
:param choices: List of choices
:param object: The object type, which will be "chat.completion.chunk"
:param created: The Unix timestamp in seconds when the chat completion was created
:param model: The model that was used to generate the chat completion
"""
id: str
choices: List[OpenAIChunkChoice]
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int
model: str
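Consumers of the streaming variant typically fold the per-chunk deltas back into a complete message, matching tool-call fragments by their index field. A hedged sketch that assumes one choice per chunk:

```python
# Sketch: accumulate text and tool-call fragments from a stream of
# OpenAIChatCompletionChunk objects (assumes a single choice per chunk).
async def accumulate(chunks):
    text = ""
    tool_calls = {}  # index -> OpenAIChatCompletionToolCall
    async for chunk in chunks:
        delta = chunk.choices[0].delta
        if delta.content:
            text += delta.content
        for fragment in delta.tool_calls or []:
            idx = fragment.index or 0
            if idx not in tool_calls or tool_calls[idx].function is None:
                tool_calls[idx] = fragment
            elif fragment.function and fragment.function.arguments:
                # Later fragments extend the JSON-encoded arguments string.
                tool_calls[idx].function.arguments = (
                    tool_calls[idx].function.arguments or ""
                ) + fragment.function.arguments
    return text, tool_calls
```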
@json_schema_type
class OpenAICompletionLogprobs(BaseModel):
"""The log probabilities for the tokens in the message from an OpenAI-compatible completion response.
@@ -872,7 +1007,7 @@ class Inference(Protocol):
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
- response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@@ -883,7 +1018,7 @@
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
- ) -> OpenAIChatCompletion:
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.


@@ -38,7 +38,13 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
- from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import (
@@ -531,7 +537,7 @@ class InferenceRouter(Inference):
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
- response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@@ -542,7 +548,7 @@ class InferenceRouter(Inference):
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
- ) -> OpenAIChatCompletion:
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
logger.debug(
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
)


@@ -204,7 +204,9 @@ class ToolUtils:
return None
elif is_json(message_body):
response = json.loads(message_body)
- if ("type" in response and response["type"] == "function") or ("name" in response):
if ("type" in response and response["type"] == "function") or (
"name" in response and "parameters" in response
):
function_name = response["name"]
args = response["parameters"]
return function_name, args
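The added `"parameters" in response` guard keeps plain JSON output that merely contains a name key from being misread as a tool call (which previously raised a KeyError). A standalone sketch of the guarded parse:

```python
import json

def maybe_extract_tool_call(message_body: str):
    """Sketch of the guard: only treat JSON as a tool call when it has an
    explicit function type or both "name" and "parameters" keys."""
    try:
        response = json.loads(message_body)
    except json.JSONDecodeError:
        return None
    if ("type" in response and response["type"] == "function") or (
        "name" in response and "parameters" in response
    ):
        return response["name"], response["parameters"]
    return None

assert maybe_extract_tool_call('{"name": "Bob"}') is None  # ordinary JSON, not a call
assert maybe_extract_tool_call('{"name": "add", "parameters": {"a": 1}}') == ("add", {"a": 1})
```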


@ -59,8 +59,8 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry, build_hf_repo_model_entry,
) )
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin, OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionUnsupportedMixin, OpenAICompletionToLlamaStackMixin,
) )
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt, augment_content_with_response_format_prompt,
@ -83,8 +83,8 @@ def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_
class MetaReferenceInferenceImpl( class MetaReferenceInferenceImpl(
OpenAICompletionUnsupportedMixin, OpenAICompletionToLlamaStackMixin,
OpenAIChatCompletionUnsupportedMixin, OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin, SentenceTransformerEmbeddingMixin,
Inference, Inference,
ModelsProtocolPrivate, ModelsProtocolPrivate,


@ -25,8 +25,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin, SentenceTransformerEmbeddingMixin,
) )
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin, OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionUnsupportedMixin, OpenAICompletionToLlamaStackMixin,
) )
from .config import SentenceTransformersInferenceConfig from .config import SentenceTransformersInferenceConfig
@ -35,8 +35,8 @@ log = logging.getLogger(__name__)
class SentenceTransformersInferenceImpl( class SentenceTransformersInferenceImpl(
OpenAIChatCompletionUnsupportedMixin, OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionUnsupportedMixin, OpenAICompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin, SentenceTransformerEmbeddingMixin,
Inference, Inference,
ModelsProtocolPrivate, ModelsProtocolPrivate,


@ -66,10 +66,10 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelsProtocolPrivate, ModelsProtocolPrivate,
) )
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin, OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice, OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse, OpenAICompatCompletionResponse,
OpenAICompletionUnsupportedMixin, OpenAICompletionToLlamaStackMixin,
get_stop_reason, get_stop_reason,
process_chat_completion_stream_response, process_chat_completion_stream_response,
) )
@ -176,8 +176,8 @@ def _convert_sampling_params(
class VLLMInferenceImpl( class VLLMInferenceImpl(
Inference, Inference,
OpenAIChatCompletionUnsupportedMixin, OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionUnsupportedMixin, OpenAICompletionToLlamaStackMixin,
ModelsProtocolPrivate, ModelsProtocolPrivate,
): ):
""" """


@@ -3,13 +3,14 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
- from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
@@ -25,9 +26,19 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
LoraFinetuningSingleDevice,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
from llama_stack.schema_utils import webmethod
class TrainingArtifactType(Enum):
CHECKPOINT = "checkpoint"
RESOURCES_STATS = "resources_stats"
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
class TorchtunePostTrainingImpl:
def __init__(
self,
@@ -38,13 +49,27 @@ class TorchtunePostTrainingImpl:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets
self._scheduler = Scheduler()
- # TODO: assume sync job, will need jobs API for async scheduling
- self.jobs = {}
- self.checkpoints_dict = {}
- async def shutdown(self):
- pass
async def shutdown(self) -> None:
await self._scheduler.shutdown()
@staticmethod
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.CHECKPOINT.value,
name=checkpoint.identifier,
uri=checkpoint.path,
metadata=dict(checkpoint),
)
@staticmethod
def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.RESOURCES_STATS.value,
name=TrainingArtifactType.RESOURCES_STATS.value,
metadata=resources_stats,
)
async def supervised_fine_tune(
self,
@@ -56,20 +81,11 @@ class TorchtunePostTrainingImpl:
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig],
) -> PostTrainingJob:
- if job_uuid in self.jobs:
- raise ValueError(f"Job {job_uuid} already exists")
- post_training_job = PostTrainingJob(job_uuid=job_uuid)
- job_status_response = PostTrainingJobStatusResponse(
- job_uuid=job_uuid,
- status=JobStatus.scheduled,
- scheduled_at=datetime.now(timezone.utc),
- )
- self.jobs[job_uuid] = job_status_response
if isinstance(algorithm_config, LoraFinetuningConfig):
- try:
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
on_log_message_cb("Starting Lora finetuning")
recipe = LoraFinetuningSingleDevice(
self.config,
job_uuid,
@@ -82,26 +98,22 @@ class TorchtunePostTrainingImpl:
self.datasetio_api,
self.datasets_api,
)
- job_status_response.status = JobStatus.in_progress
- job_status_response.started_at = datetime.now(timezone.utc)
await recipe.setup()
resources_allocated, checkpoints = await recipe.train()
- self.checkpoints_dict[job_uuid] = checkpoints
- job_status_response.resources_allocated = resources_allocated
- job_status_response.checkpoints = checkpoints
- job_status_response.status = JobStatus.completed
- job_status_response.completed_at = datetime.now(timezone.utc)
- except Exception:
- job_status_response.status = JobStatus.failed
- raise
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("Lora finetuning completed")
else:
raise NotImplementedError()
- return post_training_job
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
return PostTrainingJob(job_uuid=job_uuid)
async def preference_optimize(
self,
@@ -114,19 +126,55 @@ class TorchtunePostTrainingImpl:
) -> PostTrainingJob: ...
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
- return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs])
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)
@staticmethod
def _get_artifacts_metadata_by_type(job, artifact_type):
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
@classmethod
def _get_checkpoints(cls, job):
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
@classmethod
def _get_resources_allocated(cls, job):
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
return data[0] if data else None
@webmethod(route="/post-training/job/status") @webmethod(route="/post-training/job/status")
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
return self.jobs.get(job_uuid, None) job = self._scheduler.get_job(job_uuid)
match job.status:
# TODO: Add support for other statuses to API
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
status = JobStatus.scheduled
case SchedulerJobStatus.running:
status = JobStatus.in_progress
case SchedulerJobStatus.completed:
status = JobStatus.completed
case SchedulerJobStatus.failed:
status = JobStatus.failed
case _:
raise NotImplementedError()
return PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=status,
scheduled_at=job.scheduled_at,
started_at=job.started_at,
completed_at=job.completed_at,
checkpoints=self._get_checkpoints(job),
resources_allocated=self._get_resources_allocated(job),
)
@webmethod(route="/post-training/job/cancel") @webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None: async def cancel_training_job(self, job_uuid: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet") self._scheduler.cancel(job_uuid)
@webmethod(route="/post-training/job/artifacts") @webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
if job_uuid in self.checkpoints_dict: job = self._scheduler.get_job(job_uuid)
checkpoints = self.checkpoints_dict.get(job_uuid, []) return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=checkpoints)
return None
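With the scheduler-backed implementation, checkpoints reach API callers as artifact metadata instead of provider-held dictionaries. A hedged sketch of polling a job and reading its checkpoints; the impl handle and poll interval are placeholders:

```python
import asyncio

# Sketch: wait for a torchtune post-training job, then read checkpoint
# metadata that the handler collected as job artifacts.
async def wait_for_checkpoints(impl, job_uuid: str, poll_seconds: float = 30.0):
    while True:
        status = await impl.get_training_job_status(job_uuid)
        if status and status.status in (JobStatus.completed, JobStatus.failed):
            break
        await asyncio.sleep(poll_seconds)
    artifacts = await impl.get_training_job_artifacts(job_uuid)
    return artifacts.checkpoints  # list of checkpoint metadata dicts
```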


@ -36,10 +36,10 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,
) )
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin, OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice, OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse, OpenAICompatCompletionResponse,
OpenAICompletionUnsupportedMixin, OpenAICompletionToLlamaStackMixin,
get_sampling_strategy_options, get_sampling_strategy_options,
process_chat_completion_response, process_chat_completion_response,
process_chat_completion_stream_response, process_chat_completion_stream_response,
@ -56,8 +56,8 @@ from .models import MODEL_ENTRIES
class BedrockInferenceAdapter( class BedrockInferenceAdapter(
ModelRegistryHelper, ModelRegistryHelper,
Inference, Inference,
OpenAIChatCompletionUnsupportedMixin, OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionUnsupportedMixin, OpenAICompletionToLlamaStackMixin,
): ):
def __init__(self, config: BedrockConfig) -> None: def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES) ModelRegistryHelper.__init__(self, MODEL_ENTRIES)


@ -34,8 +34,8 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,
) )
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin, OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionUnsupportedMixin, OpenAICompletionToLlamaStackMixin,
get_sampling_options, get_sampling_options,
process_chat_completion_response, process_chat_completion_response,
process_chat_completion_stream_response, process_chat_completion_stream_response,
@ -54,8 +54,8 @@ from .models import MODEL_ENTRIES
class CerebrasInferenceAdapter( class CerebrasInferenceAdapter(
ModelRegistryHelper, ModelRegistryHelper,
Inference, Inference,
OpenAIChatCompletionUnsupportedMixin, OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionUnsupportedMixin, OpenAICompletionToLlamaStackMixin,
): ):
def __init__(self, config: CerebrasImplConfig) -> None: def __init__(self, config: CerebrasImplConfig) -> None:
ModelRegistryHelper.__init__( ModelRegistryHelper.__init__(


@ -34,8 +34,8 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry, build_hf_repo_model_entry,
) )
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin, OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionUnsupportedMixin, OpenAICompletionToLlamaStackMixin,
get_sampling_options, get_sampling_options,
process_chat_completion_response, process_chat_completion_response,
process_chat_completion_stream_response, process_chat_completion_stream_response,
@ -61,8 +61,8 @@ model_entries = [
class DatabricksInferenceAdapter( class DatabricksInferenceAdapter(
ModelRegistryHelper, ModelRegistryHelper,
Inference, Inference,
OpenAIChatCompletionUnsupportedMixin, OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionUnsupportedMixin, OpenAICompletionToLlamaStackMixin,
): ):
def __init__(self, config: DatabricksImplConfig) -> None: def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=model_entries) ModelRegistryHelper.__init__(self, model_entries=model_entries)


@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
- from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from fireworks.client import Fireworks
from openai import AsyncOpenAI
@@ -32,13 +32,20 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
- from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
convert_message_to_openai_dict,
get_sampling_options,
prepare_openai_completion_params,
@ -301,6 +308,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
prompt_logprobs: Optional[int] = None, prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion: ) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model) model_obj = await self.model_store.get_model(model)
# Fireworks always prepends the prompt with BOS, so strip a leading <|begin_of_text|> to avoid duplicating it
if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
prompt = prompt[len("<|begin_of_text|>") :]
params = await prepare_openai_completion_params( params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id, model=model_obj.provider_resource_id,
prompt=prompt, prompt=prompt,
@ -320,6 +332,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
top_p=top_p, top_p=top_p,
user=user, user=user,
) )
return await self._get_openai_client().completions.create(**params) return await self._get_openai_client().completions.create(**params)
async def openai_chat_completion( async def openai_chat_completion(
@ -336,7 +349,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
n: Optional[int] = None, n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None, parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None, presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None, response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
@ -347,10 +360,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
top_logprobs: Optional[int] = None, top_logprobs: Optional[int] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
user: Optional[str] = None, user: Optional[str] = None,
) -> OpenAIChatCompletion: ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model) model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params( params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages, messages=messages,
frequency_penalty=frequency_penalty, frequency_penalty=frequency_penalty,
function_call=function_call, function_call=function_call,
@ -374,4 +386,12 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
top_p=top_p, top_p=top_p,
user=user, user=user,
) )
return await self._get_openai_client().chat.completions.create(**params)
# Divert Llama models through the Llama Stack inference APIs because
# Fireworks' OpenAI-compatible chat completions API does not handle
# tool calls properly.
llama_model = self.get_llama_model(model_obj.provider_resource_id)
if llama_model:
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(self, model=model, **params)
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)

View file

@ -4,8 +4,24 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from openai import AsyncOpenAI
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChoiceDelta,
OpenAIChunkChoice,
OpenAIMessageParam,
OpenAIResponseFormatParam,
OpenAISystemMessageParam,
)
from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import (
prepare_openai_completion_params,
)
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
@ -21,9 +37,129 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
provider_data_api_key_field="groq_api_key", provider_data_api_key_field="groq_api_key",
) )
self.config = config self.config = config
self._openai_client = None
async def initialize(self): async def initialize(self):
await super().initialize() await super().initialize()
async def shutdown(self): async def shutdown(self):
await super().shutdown() await super().shutdown()
if self._openai_client:
await self._openai_client.close()
self._openai_client = None
def _get_openai_client(self) -> AsyncOpenAI:
if not self._openai_client:
self._openai_client = AsyncOpenAI(
base_url=f"{self.config.url}/openai/v1",
api_key=self.config.api_key,
)
return self._openai_client
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
# Groq does not support json_schema response format, so we need to convert it to json_object
if response_format and response_format.type == "json_schema":
response_format.type = "json_object"
schema = response_format.json_schema.get("schema", {})
response_format.json_schema = None
json_instructions = f"\nYour response should be a JSON object that matches the following schema: {schema}"
if messages and messages[0].role == "system":
messages[0].content = messages[0].content + json_instructions
else:
messages.insert(0, OpenAISystemMessageParam(content=json_instructions))
# Groq returns a 400 error if tools are provided but none are called
# So, set tool_choice to "required" to attempt to force a call
if tools and (not tool_choice or tool_choice == "auto"):
tool_choice = "required"
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id.replace("groq/", ""),
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
# Groq does not support streaming requests that set response_format
fake_stream = False
if stream and response_format:
params["stream"] = False
fake_stream = True
response = await self._get_openai_client().chat.completions.create(**params)
if fake_stream:
chunk_choices = []
for choice in response.choices:
delta = OpenAIChoiceDelta(
content=choice.message.content,
role=choice.message.role,
tool_calls=choice.message.tool_calls,
)
chunk_choice = OpenAIChunkChoice(
delta=delta,
finish_reason=choice.finish_reason,
index=choice.index,
logprobs=None,
)
chunk_choices.append(chunk_choice)
chunk = OpenAIChatCompletionChunk(
id=response.id,
choices=chunk_choices,
object="chat.completion.chunk",
created=response.created,
model=response.model,
)
async def _fake_stream_generator():
yield chunk
return _fake_stream_generator()
else:
return response
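The method above works around two Groq quirks: a json_schema response format is downgraded to json_object with the schema appended to the system message, and a streaming request that sets response_format is executed without streaming and then re-wrapped as a one-chunk stream. Below is a minimal, self-contained sketch of that re-wrapping idea; the dataclasses are simplified stand-ins for the real OpenAI/Llama Stack response and chunk models, not the types used by the adapter.
# Simplified sketch of the "fake streaming" wrapper above; these dataclasses
# stand in for the real response/chunk models and are not the adapter's types.
import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, List, Optional


@dataclass
class Message:
    content: str
    role: str = "assistant"


@dataclass
class Choice:
    message: Message
    finish_reason: str = "stop"
    index: int = 0


@dataclass
class Response:
    id: str
    choices: List[Choice]


@dataclass
class Delta:
    content: Optional[str] = None
    role: Optional[str] = None


@dataclass
class ChunkChoice:
    delta: Delta
    finish_reason: Optional[str]
    index: int


@dataclass
class Chunk:
    id: str
    choices: List[ChunkChoice]
    object: str = "chat.completion.chunk"


def wrap_as_single_chunk_stream(response: Response) -> AsyncIterator[Chunk]:
    # Convert a completed, non-streaming response into a one-chunk async stream.
    chunk = Chunk(
        id=response.id,
        choices=[
            ChunkChoice(
                delta=Delta(content=c.message.content, role=c.message.role),
                finish_reason=c.finish_reason,
                index=c.index,
            )
            for c in response.choices
        ],
    )

    async def _gen() -> AsyncIterator[Chunk]:
        yield chunk

    return _gen()


async def _demo() -> None:
    resp = Response(id="chatcmpl-123", choices=[Choice(Message("Hello!"))])
    async for chunk in wrap_as_single_chunk_stream(resp):
        print(chunk.choices[0].delta.content)


if __name__ == "__main__":
    asyncio.run(_demo())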

View file

@ -39,8 +39,16 @@ MODEL_ENTRIES = [
"groq/llama-4-scout-17b-16e-instruct", "groq/llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value, CoreModelId.llama4_scout_17b_16e_instruct.value,
), ),
build_hf_repo_model_entry(
"groq/meta-llama/llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry( build_hf_repo_model_entry(
"groq/llama-4-maverick-17b-128e-instruct", "groq/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value, CoreModelId.llama4_maverick_17b_128e_instruct.value,
), ),
build_hf_repo_model_entry(
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
] ]

View file

@ -35,7 +35,13 @@ from llama_stack.apis.inference import (
ToolConfig, ToolConfig,
ToolDefinition, ToolDefinition,
) )
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.models.llama.datatypes import ToolPromptFormat from llama_stack.models.llama.datatypes import ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,
@ -329,7 +335,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
n: Optional[int] = None, n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None, parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None, presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None, response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
@ -340,7 +346,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
top_logprobs: Optional[int] = None, top_logprobs: Optional[int] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
user: Optional[str] = None, user: Optional[str] = None,
) -> OpenAIChatCompletion: ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
provider_model_id = self.get_provider_model_id(model) provider_model_id = self.get_provider_model_id(model)
params = await prepare_openai_completion_params( params = await prepare_openai_completion_params(

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Optional, Union from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
import httpx import httpx
from ollama import AsyncClient from ollama import AsyncClient
@ -39,7 +39,13 @@ from llama_stack.apis.inference import (
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
@ -337,6 +343,12 @@ class OllamaInferenceAdapter(
response = await self.client.list() response = await self.client.list()
available_models = [m["model"] for m in response["models"]] available_models = [m["model"] for m in response["models"]]
if model.provider_resource_id not in available_models: if model.provider_resource_id not in available_models:
available_models_latest = [m["model"].split(":latest")[0] for m in response["models"]]
if model.provider_resource_id in available_models_latest:
logger.warning(
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
)
return model
raise ValueError( raise ValueError(
f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}" f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
) )
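# The fallback above tolerates a model registered without a tag when Ollama
# only lists it as "<name>:latest"; the adapter just logs a warning and keeps
# the registered id. A standalone sketch of the same resolution logic, with
# illustrative model names (not taken from a live Ollama server):
from typing import List, Optional


def resolve_ollama_model(requested: str, available: List[str]) -> Optional[str]:
    if requested in available:
        return requested
    # Ollama reports untagged pulls as "<name>:latest"; accept a bare "<name>" too.
    untagged = [name.split(":latest")[0] for name in available]
    if requested in untagged:
        return f"{requested}:latest"
    return None


print(resolve_ollama_model("llama3.2", ["llama3.2:latest"]))  # llama3.2:latest
print(resolve_ollama_model("llama3.2:3b", ["llama3.2:latest"]))  # None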
@ -408,7 +420,7 @@ class OllamaInferenceAdapter(
n: Optional[int] = None, n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None, parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None, presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None, response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
@ -419,7 +431,7 @@ class OllamaInferenceAdapter(
top_logprobs: Optional[int] = None, top_logprobs: Optional[int] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
user: Optional[str] = None, user: Optional[str] = None,
) -> OpenAIChatCompletion: ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self._get_model(model) model_obj = await self._get_model(model)
params = { params = {
k: v k: v

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Optional, Union from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from llama_stack_client import AsyncLlamaStackClient from llama_stack_client import AsyncLlamaStackClient
@ -26,7 +26,13 @@ from llama_stack.apis.inference import (
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model from llama_stack.apis.models import Model
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
@ -266,7 +272,7 @@ class PassthroughInferenceAdapter(Inference):
n: Optional[int] = None, n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None, parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None, presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None, response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
@ -277,7 +283,7 @@ class PassthroughInferenceAdapter(Inference):
top_logprobs: Optional[int] = None, top_logprobs: Optional[int] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
user: Optional[str] = None, user: Optional[str] = None,
) -> OpenAIChatCompletion: ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
client = self._get_client() client = self._get_client()
model_obj = await self.model_store.get_model(model) model_obj = await self.model_store.get_model(model)

View file

@ -12,8 +12,8 @@ from llama_stack.apis.inference import * # noqa: F403
# from llama_stack.providers.datatypes import ModelsProtocolPrivate # from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin, OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionUnsupportedMixin, OpenAICompletionToLlamaStackMixin,
get_sampling_options, get_sampling_options,
process_chat_completion_response, process_chat_completion_response,
process_chat_completion_stream_response, process_chat_completion_stream_response,
@ -43,8 +43,8 @@ RUNPOD_SUPPORTED_MODELS = {
class RunpodInferenceAdapter( class RunpodInferenceAdapter(
ModelRegistryHelper, ModelRegistryHelper,
Inference, Inference,
OpenAIChatCompletionUnsupportedMixin, OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionUnsupportedMixin, OpenAICompletionToLlamaStackMixin,
): ):
def __init__(self, config: RunpodImplConfig) -> None: def __init__(self, config: RunpodImplConfig) -> None:
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS) ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)

View file

@ -40,10 +40,10 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry, build_hf_repo_model_entry,
) )
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin, OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice, OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse, OpenAICompatCompletionResponse,
OpenAICompletionUnsupportedMixin, OpenAICompletionToLlamaStackMixin,
get_sampling_options, get_sampling_options,
process_chat_completion_response, process_chat_completion_response,
process_chat_completion_stream_response, process_chat_completion_stream_response,
@ -73,8 +73,8 @@ def build_hf_repo_model_entries():
class _HfAdapter( class _HfAdapter(
Inference, Inference,
OpenAIChatCompletionUnsupportedMixin, OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionUnsupportedMixin, OpenAICompletionToLlamaStackMixin,
ModelsProtocolPrivate, ModelsProtocolPrivate,
): ):
client: AsyncInferenceClient client: AsyncInferenceClient

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Optional, Union from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from openai import AsyncOpenAI from openai import AsyncOpenAI
from together import AsyncTogether from together import AsyncTogether
@ -31,7 +31,13 @@ from llama_stack.apis.inference import (
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
@ -315,7 +321,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
n: Optional[int] = None, n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None, parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None, presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None, response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
@ -326,7 +332,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
top_logprobs: Optional[int] = None, top_logprobs: Optional[int] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
user: Optional[str] = None, user: Optional[str] = None,
) -> OpenAIChatCompletion: ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model) model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params( params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id, model=model_obj.provider_resource_id,
@ -353,4 +359,26 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
top_p=top_p, top_p=top_p,
user=user, user=user,
) )
if params.get("stream", True):
return self._stream_openai_chat_completion(params)
return await self._get_openai_client().chat.completions.create(**params) # type: ignore return await self._get_openai_client().chat.completions.create(**params) # type: ignore
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
# together.ai sometimes adds usage data to the stream, even if include_usage is False
# This causes an unexpected final chunk with empty choices array to be sent
# to clients that may not handle it gracefully.
include_usage = False
if params.get("stream_options", None):
include_usage = params["stream_options"].get("include_usage", False)
stream = await self._get_openai_client().chat.completions.create(**params)
seen_finish_reason = False
async for chunk in stream:
# Trailing usage-only chunk (empty choices) that the user didn't request, so discard it
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
break
yield chunk
for choice in chunk.choices:
if choice.finish_reason:
seen_finish_reason = True
break
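The helper above exists because together.ai can append a usage-only chunk, with an empty choices list, after the chunk that carries finish_reason, even when the caller never asked for usage. A self-contained sketch of the same filtering, driven by a mocked chunk stream of plain dicts rather than real OpenAI chunk objects:
# Self-contained sketch of the trailing usage-chunk filter above, using plain
# dicts as mock chunks instead of real OpenAI chunk objects.
import asyncio
from typing import Any, AsyncIterator, Dict


async def filter_uninvited_usage_chunks(
    stream: AsyncIterator[Dict[str, Any]], include_usage: bool
) -> AsyncIterator[Dict[str, Any]]:
    seen_finish_reason = False
    async for chunk in stream:
        # A usage-only chunk has no choices; drop it if usage was not requested.
        if not include_usage and seen_finish_reason and not chunk["choices"]:
            break
        yield chunk
        if any(choice.get("finish_reason") for choice in chunk["choices"]):
            seen_finish_reason = True


async def _mock_stream() -> AsyncIterator[Dict[str, Any]]:
    yield {"choices": [{"delta": {"content": "Hi"}, "finish_reason": None}]}
    yield {"choices": [{"delta": {}, "finish_reason": "stop"}]}
    yield {"choices": [], "usage": {"total_tokens": 7}}  # uninvited usage chunk


async def _demo() -> None:
    async for chunk in filter_uninvited_usage_chunks(_mock_stream(), include_usage=False):
        print(chunk)


if __name__ == "__main__":
    asyncio.run(_demo())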

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
import logging import logging
from typing import Any, AsyncGenerator, Dict, List, Optional, Union from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
import httpx import httpx
from openai import AsyncOpenAI from openai import AsyncOpenAI
@ -45,7 +45,12 @@ from llama_stack.apis.inference import (
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.models.llama.sku_list import all_registered_models
@ -487,7 +492,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
n: Optional[int] = None, n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None, parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None, presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None, response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
@ -498,7 +503,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
top_logprobs: Optional[int] = None, top_logprobs: Optional[int] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
user: Optional[str] = None, user: Optional[str] = None,
) -> OpenAIChatCompletion: ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self._get_model(model) model_obj = await self._get_model(model)
params = await prepare_openai_completion_params( params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id, model=model_obj.provider_resource_id,

View file

@ -30,7 +30,13 @@ from llama_stack.apis.inference import (
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models.models import Model from llama_stack.apis.models.models import Model
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -270,7 +276,7 @@ class LiteLLMOpenAIMixin(
guided_choice: Optional[List[str]] = None, guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None, prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion: ) -> OpenAICompletion:
model_obj = await self._get_model(model) model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params( params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id, model=model_obj.provider_resource_id,
prompt=prompt, prompt=prompt,
@ -292,7 +298,7 @@ class LiteLLMOpenAIMixin(
guided_choice=guided_choice, guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs, prompt_logprobs=prompt_logprobs,
) )
return litellm.text_completion(**params) return await litellm.atext_completion(**params)
async def openai_chat_completion( async def openai_chat_completion(
self, self,
@ -308,7 +314,7 @@ class LiteLLMOpenAIMixin(
n: Optional[int] = None, n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None, parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None, presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None, response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
@ -319,8 +325,8 @@ class LiteLLMOpenAIMixin(
top_logprobs: Optional[int] = None, top_logprobs: Optional[int] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
user: Optional[str] = None, user: Optional[str] = None,
) -> OpenAIChatCompletion: ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self._get_model(model) model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params( params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id, model=model_obj.provider_resource_id,
messages=messages, messages=messages,
@ -346,7 +352,7 @@ class LiteLLMOpenAIMixin(
top_p=top_p, top_p=top_p,
user=user, user=user,
) )
return litellm.completion(**params) return await litellm.acompletion(**params)
async def batch_completion( async def batch_completion(
self, self,

View file

@ -8,7 +8,7 @@ import logging
import time import time
import uuid import uuid
import warnings import warnings
from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Union
from openai import AsyncStream from openai import AsyncStream
from openai.types.chat import ( from openai.types.chat import (
@ -50,6 +50,18 @@ from openai.types.chat.chat_completion import (
from openai.types.chat.chat_completion import ( from openai.types.chat.chat_completion import (
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
) )
from openai.types.chat.chat_completion_chunk import (
Choice as OpenAIChatCompletionChunkChoice,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDelta as OpenAIChoiceDelta,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
)
from openai.types.chat.chat_completion_content_part_image_param import ( from openai.types.chat.chat_completion_content_part_image_param import (
ImageURL as OpenAIImageURL, ImageURL as OpenAIImageURL,
) )
@ -59,6 +71,7 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
URL,
ImageContentItem, ImageContentItem,
InterleavedContent, InterleavedContent,
TextContentItem, TextContentItem,
@ -85,12 +98,24 @@ from llama_stack.apis.inference import (
TopPSamplingStrategy, TopPSamplingStrategy,
UserMessage, UserMessage,
) )
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAICompletionChoice from llama_stack.apis.inference.inference import (
JsonSchemaResponseFormat,
OpenAIChatCompletion,
OpenAICompletion,
OpenAICompletionChoice,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ToolConfig,
)
from llama_stack.apis.inference.inference import (
OpenAIChoice as OpenAIChatCompletionChoice,
)
from llama_stack.models.llama.datatypes import ( from llama_stack.models.llama.datatypes import (
BuiltinTool, BuiltinTool,
StopReason, StopReason,
ToolCall, ToolCall,
ToolDefinition, ToolDefinition,
ToolParamDefinition,
) )
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url, convert_image_content_to_url,
@ -751,6 +776,17 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
return out return out
def _convert_stop_reason_to_openai_finish_reason(stop_reason: StopReason) -> str:
"""
Convert a StopReason to an OpenAI chat completion finish_reason.
"""
return {
StopReason.end_of_turn: "stop",
StopReason.end_of_message: "tool_calls",
StopReason.out_of_tokens: "length",
}.get(stop_reason, "stop")
def _convert_openai_finish_reason(finish_reason: str) -> StopReason: def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
""" """
Convert an OpenAI chat completion finish_reason to a StopReason. Convert an OpenAI chat completion finish_reason to a StopReason.
@ -776,6 +812,56 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
}.get(finish_reason, StopReason.end_of_turn) }.get(finish_reason, StopReason.end_of_turn)
def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[str, Any]]] = None) -> ToolConfig:
tool_config = ToolConfig()
if tool_choice:
tool_config.tool_choice = tool_choice
return tool_config
def _convert_openai_request_tools(tools: Optional[List[Dict[str, Any]]] = None) -> List[ToolDefinition]:
lls_tools = []
if not tools:
return lls_tools
for tool in tools:
tool_fn = tool.get("function", {})
tool_name = tool_fn.get("name", None)
tool_desc = tool_fn.get("description", None)
tool_params = tool_fn.get("parameters", None)
lls_tool_params = {}
if tool_params is not None:
tool_param_properties = tool_params.get("properties", {})
for tool_param_key, tool_param_value in tool_param_properties.items():
tool_param_def = ToolParamDefinition(
param_type=tool_param_value.get("type", None),
description=tool_param_value.get("description", None),
)
lls_tool_params[tool_param_key] = tool_param_def
lls_tool = ToolDefinition(
tool_name=tool_name,
description=tool_desc,
parameters=lls_tool_params,
)
lls_tools.append(lls_tool)
return lls_tools
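# Illustration of which fields _convert_openai_request_tools reads from an
# OpenAI-style tool definition; the "get_weather" tool is a made-up example.
openai_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "City name"},
                "unit": {"type": "string", "description": "celsius or fahrenheit"},
            },
            "required": ["city"],
        },
    },
}

fn = openai_tool.get("function", {})
extracted = {
    "tool_name": fn.get("name"),
    "description": fn.get("description"),
    "parameters": {
        key: {"param_type": value.get("type"), "description": value.get("description")}
        for key, value in fn.get("parameters", {}).get("properties", {}).items()
    },
}
print(extracted)
# Note: the OpenAI "required" list is not carried over by the converter above.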
def _convert_openai_request_response_format(response_format: OpenAIResponseFormatParam = None):
if not response_format:
return None
# response_format can be a dict or a pydantic model
response_format = dict(response_format)
if response_format.get("type", "") == "json_schema":
return JsonSchemaResponseFormat(
type="json_schema",
json_schema=response_format.get("json_schema", {}).get("schema", ""),
)
return None
def _convert_openai_tool_calls( def _convert_openai_tool_calls(
tool_calls: List[OpenAIChatCompletionMessageToolCall], tool_calls: List[OpenAIChatCompletionMessageToolCall],
) -> List[ToolCall]: ) -> List[ToolCall]:
@ -871,6 +957,40 @@ def _convert_openai_sampling_params(
return sampling_params return sampling_params
def _convert_openai_request_messages(messages: List[OpenAIMessageParam]):
# Llama Stack messages and OpenAI messages are similar, but not identical.
lls_messages = []
for message in messages:
lls_message = dict(message)
# Llama Stack expects `call_id` but OpenAI uses `tool_call_id`
tool_call_id = lls_message.pop("tool_call_id", None)
if tool_call_id:
lls_message["call_id"] = tool_call_id
content = lls_message.get("content", None)
if isinstance(content, list):
lls_content = []
for item in content:
# items can either be pydantic models or dicts here...
item = dict(item)
if item.get("type", "") == "image_url":
lls_item = ImageContentItem(
type="image",
image=URL(uri=item.get("image_url", {}).get("url", "")),
)
elif item.get("type", "") == "text":
lls_item = TextContentItem(
type="text",
text=item.get("text", ""),
)
lls_content.append(lls_item)
lls_message["content"] = lls_content
lls_messages.append(lls_message)
return lls_messages
def convert_openai_chat_completion_choice( def convert_openai_chat_completion_choice(
choice: OpenAIChoice, choice: OpenAIChoice,
) -> ChatCompletionResponse: ) -> ChatCompletionResponse:
@ -1080,11 +1200,24 @@ async def convert_openai_chat_completion_stream(
async def prepare_openai_completion_params(**params): async def prepare_openai_completion_params(**params):
completion_params = {k: v for k, v in params.items() if v is not None} async def _prepare_value(value: Any) -> Any:
new_value = value
if isinstance(value, list):
new_value = [await _prepare_value(v) for v in value]
elif isinstance(value, dict):
new_value = {k: await _prepare_value(v) for k, v in value.items()}
elif isinstance(value, BaseModel):
new_value = value.model_dump(exclude_none=True)
return new_value
completion_params = {}
for k, v in params.items():
if v is not None:
completion_params[k] = await _prepare_value(v)
return completion_params return completion_params
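# prepare_openai_completion_params now drops None values and recursively
# serializes pydantic models (including ones nested in lists/dicts) via
# model_dump(exclude_none=True). A synchronous sketch of the same idea, using
# a throwaway pydantic model (the real helper is async):
from typing import Any, Optional

from pydantic import BaseModel


class ExampleResponseFormat(BaseModel):
    type: str
    json_schema: Optional[dict] = None


def prepare(value: Any) -> Any:
    if isinstance(value, list):
        return [prepare(v) for v in value]
    if isinstance(value, dict):
        return {k: prepare(v) for k, v in value.items()}
    if isinstance(value, BaseModel):
        return value.model_dump(exclude_none=True)
    return value


params = {
    "model": "some-model",
    "response_format": ExampleResponseFormat(type="json_object"),
    "temperature": None,  # dropped entirely
}
print({k: prepare(v) for k, v in params.items() if v is not None})
# -> {'model': 'some-model', 'response_format': {'type': 'json_object'}}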
class OpenAICompletionUnsupportedMixin: class OpenAICompletionToLlamaStackMixin:
async def openai_completion( async def openai_completion(
self, self,
model: str, model: str,
@ -1122,6 +1255,7 @@ class OpenAICompletionUnsupportedMixin:
choices = [] choices = []
# "n" is the number of completions to generate per prompt # "n" is the number of completions to generate per prompt
n = n or 1
for _i in range(0, n): for _i in range(0, n):
# and we may have multiple prompts, if batching was used # and we may have multiple prompts, if batching was used
@ -1134,7 +1268,7 @@ class OpenAICompletionUnsupportedMixin:
index = len(choices) index = len(choices)
text = result.content text = result.content
finish_reason = _convert_openai_finish_reason(result.stop_reason) finish_reason = _convert_stop_reason_to_openai_finish_reason(result.stop_reason)
choice = OpenAICompletionChoice( choice = OpenAICompletionChoice(
index=index, index=index,
@ -1152,7 +1286,7 @@ class OpenAICompletionUnsupportedMixin:
) )
class OpenAIChatCompletionUnsupportedMixin: class OpenAIChatCompletionToLlamaStackMixin:
async def openai_chat_completion( async def openai_chat_completion(
self, self,
model: str, model: str,
@ -1167,7 +1301,7 @@ class OpenAIChatCompletionUnsupportedMixin:
n: Optional[int] = None, n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None, parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None, presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None, response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None, stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
@ -1178,5 +1312,103 @@ class OpenAIChatCompletionUnsupportedMixin:
top_logprobs: Optional[int] = None, top_logprobs: Optional[int] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
user: Optional[str] = None, user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
messages = _convert_openai_request_messages(messages)
response_format = _convert_openai_request_response_format(response_format)
sampling_params = _convert_openai_sampling_params(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
)
tool_config = _convert_openai_request_tool_config(tool_choice)
tools = _convert_openai_request_tools(tools)
outstanding_responses = []
# "n" is the number of completions to generate per prompt
n = n or 1
for _i in range(0, n):
response = self.chat_completion(
model_id=model,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
tool_config=tool_config,
tools=tools,
)
outstanding_responses.append(response)
if stream:
return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response(
self, model, outstanding_responses
)
async def _process_stream_response(
self, model: str, outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]]
):
id = f"chatcmpl-{uuid.uuid4()}"
for outstanding_response in outstanding_responses:
response = await outstanding_response
i = 0
async for chunk in response:
event = chunk.event
finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
if isinstance(event.delta, TextDelta):
text_delta = event.delta.text
delta = OpenAIChoiceDelta(content=text_delta)
yield OpenAIChatCompletionChunk(
id=id,
choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],
created=int(time.time()),
model=model,
object="chat.completion.chunk",
)
elif isinstance(event.delta, ToolCallDelta):
if event.delta.parse_status == ToolCallParseStatus.succeeded:
tool_call = event.delta.tool_call
openai_tool_call = OpenAIChoiceDeltaToolCall(
index=0,
id=tool_call.call_id,
function=OpenAIChoiceDeltaToolCallFunction(
name=tool_call.tool_name, arguments=tool_call.arguments_json
),
)
delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
yield OpenAIChatCompletionChunk(
id=id,
choices=[
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
],
created=int(time.time()),
model=model,
object="chat.completion.chunk",
)
i = i + 1
async def _process_non_stream_response(
self, model: str, outstanding_responses: List[Awaitable[ChatCompletionResponse]]
) -> OpenAIChatCompletion: ) -> OpenAIChatCompletion:
raise ValueError(f"{self.__class__.__name__} doesn't support openai chat completion") choices = []
for outstanding_response in outstanding_responses:
response = await outstanding_response
completion_message = response.completion_message
message = await convert_message_to_openai_dict_new(completion_message)
finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason)
choice = OpenAIChatCompletionChoice(
index=len(choices),
message=message,
finish_reason=finish_reason,
)
choices.append(choice)
return OpenAIChatCompletion(
id=f"chatcmpl-{uuid.uuid4()}",
choices=choices,
created=int(time.time()),
model=model,
object="chat.completion",
)
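For providers that adopt these mixins, the OpenAI-compatible entry points behave the same from the caller's perspective whether the provider talks to a native OpenAI-compatible backend or translates through chat_completion. A hedged usage sketch follows; the adapter instance, model id, and message payload are placeholders, not a concrete configuration.
# Usage sketch for an adapter that mixes in OpenAIChatCompletionToLlamaStackMixin;
# "adapter" and the model id below are placeholders.
async def stream_demo(adapter) -> None:
    result = await adapter.openai_chat_completion(
        model="example-model-id",
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,
    )
    # With stream=True the mixin returns an async iterator of chunks synthesized
    # from the underlying chat_completion stream.
    async for chunk in result:
        for choice in chunk.choices:
            if choice.delta.content:
                print(choice.delta.content, end="")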

View file

@ -0,0 +1,265 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import abc
import asyncio
import functools
import threading
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Callable, Coroutine, Dict, Iterable, Tuple, TypeAlias
from pydantic import BaseModel
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="scheduler")
# TODO: revisit the list of possible statuses when defining a more coherent
# Jobs API for all API flows; e.g. do we need new vs scheduled?
class JobStatus(Enum):
new = "new"
scheduled = "scheduled"
running = "running"
failed = "failed"
completed = "completed"
JobID: TypeAlias = str
JobType: TypeAlias = str
class JobArtifact(BaseModel):
type: JobType
name: str
# TODO: uri should be a reference to /files API; revisit when /files is implemented
uri: str | None = None
metadata: Dict[str, Any]
JobHandler = Callable[
[Callable[[str], None], Callable[[JobStatus], None], Callable[[JobArtifact], None]], Coroutine[Any, Any, None]
]
LogMessage: TypeAlias = Tuple[datetime, str]
_COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}
class Job:
def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler):
super().__init__()
self.id = job_id
self._type = job_type
self._handler = handler
self._artifacts: list[JobArtifact] = []
self._logs: list[LogMessage] = []
self._state_transitions: list[Tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]
@property
def handler(self) -> JobHandler:
return self._handler
@property
def status(self) -> JobStatus:
return self._state_transitions[-1][1]
@status.setter
def status(self, status: JobStatus):
if status in _COMPLETED_STATUSES and self.status in _COMPLETED_STATUSES:
raise ValueError(f"Job is already in a completed state ({self.status})")
if self.status == status:
return
self._state_transitions.append((datetime.now(timezone.utc), status))
@property
def artifacts(self) -> list[JobArtifact]:
return self._artifacts
def register_artifact(self, artifact: JobArtifact) -> None:
self._artifacts.append(artifact)
def _find_state_transition_date(self, status: Iterable[JobStatus]) -> datetime | None:
for date, s in reversed(self._state_transitions):
if s in status:
return date
return None
@property
def scheduled_at(self) -> datetime | None:
return self._find_state_transition_date([JobStatus.scheduled])
@property
def started_at(self) -> datetime | None:
return self._find_state_transition_date([JobStatus.running])
@property
def completed_at(self) -> datetime | None:
return self._find_state_transition_date(_COMPLETED_STATUSES)
@property
def logs(self) -> list[LogMessage]:
return self._logs[:]
def append_log(self, message: LogMessage) -> None:
self._logs.append(message)
# TODO: implement
def cancel(self) -> None:
raise NotImplementedError
class _SchedulerBackend(abc.ABC):
@abc.abstractmethod
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
raise NotImplementedError
@abc.abstractmethod
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
raise NotImplementedError
@abc.abstractmethod
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
raise NotImplementedError
@abc.abstractmethod
async def shutdown(self) -> None:
raise NotImplementedError
@abc.abstractmethod
def schedule(
self,
job: Job,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
raise NotImplementedError
class _NaiveSchedulerBackend(_SchedulerBackend):
def __init__(self, timeout: int = 5):
self._timeout = timeout
self._loop = asyncio.new_event_loop()
# There may be performance implications of using threads due to the Python
# GIL; we may need to measure whether it's a real problem
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
def _run_loop(self) -> None:
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
# When stopping the loop, give tasks a chance to finish
# TODO: should we explicitly inform jobs of pending stoppage?
for task in asyncio.all_tasks(self._loop):
self._loop.run_until_complete(task)
self._loop.close()
async def shutdown(self) -> None:
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
# TODO: decouple scheduling and running the job
def schedule(
self,
job: Job,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
async def do():
try:
job.status = JobStatus.running
await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
except Exception as e:
on_log_message_cb(str(e))
job.status = JobStatus.failed
logger.exception(f"Job {job.id} failed.")
asyncio.run_coroutine_threadsafe(do(), self._loop)
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
pass
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
pass
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
pass
_BACKENDS = {
"naive": _NaiveSchedulerBackend,
}
def _get_backend_impl(backend: str) -> _SchedulerBackend:
try:
return _BACKENDS[backend]()
except KeyError as e:
raise ValueError(f"Unknown backend {backend}") from e
class Scheduler:
def __init__(self, backend: str = "naive"):
# TODO: if server crashes, job states are lost; we need to persist jobs on disk
self._jobs: dict[JobID, Job] = {}
self._backend = _get_backend_impl(backend)
def _on_log_message_cb(self, job: Job, message: str) -> None:
msg = (datetime.now(timezone.utc), message)
# At least for the time being, until there's a better way to expose
# logs to users, also log messages to the console
logger.info(f"Job {job.id}: {message}")
job.append_log(msg)
self._backend.on_log_message_cb(job, msg)
def _on_status_change_cb(self, job: Job, status: JobStatus) -> None:
job.status = status
self._backend.on_status_change_cb(job, status)
def _on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
job.register_artifact(artifact)
self._backend.on_artifact_collected_cb(job, artifact)
def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
job = Job(type_, job_id, handler)
if job.id in self._jobs:
raise ValueError(f"Job {job.id} already exists")
self._jobs[job.id] = job
job.status = JobStatus.scheduled
self._backend.schedule(
job,
functools.partial(self._on_log_message_cb, job),
functools.partial(self._on_status_change_cb, job),
functools.partial(self._on_artifact_collected_cb, job),
)
return job.id
def cancel(self, job_id: JobID) -> None:
self.get_job(job_id).cancel()
def get_job(self, job_id: JobID) -> Job:
try:
return self._jobs[job_id]
except KeyError as e:
raise ValueError(f"Job {job_id} not found") from e
def get_jobs(self, type_: JobType | None = None) -> list[Job]:
jobs = list(self._jobs.values())
if type_:
jobs = [job for job in jobs if job._type == type_]
return jobs
async def shutdown(self):
# TODO: also cancel jobs once implemented
await self._backend.shutdown()
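The Scheduler runs each handler on a background event-loop thread and records status transitions, log messages, and artifacts through the three callbacks passed to the handler. A minimal usage sketch, mirroring the tests further down; the job type and id strings are arbitrary.
# Minimal usage sketch for the Scheduler above; job type/id are arbitrary strings.
import asyncio

from llama_stack.providers.utils.scheduler import JobStatus, Scheduler


async def main() -> None:
    sched = Scheduler()  # defaults to the in-process "naive" backend

    async def handler(on_log, on_status, on_artifact):
        on_log("starting work")
        on_artifact({"type": "checkpoint", "name": "step-1", "uri": None, "metadata": {}})
        on_status(JobStatus.completed)

    sched.schedule("example-type", "example-job", handler)
    await sched.shutdown()  # lets scheduled work finish before returning

    job = sched.get_job("example-job")
    print(job.status, [message for _, message in job.logs])


if __name__ == "__main__":
    asyncio.run(main())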

View file

@ -391,6 +391,16 @@ models:
provider_id: groq provider_id: groq
provider_model_id: groq/llama-4-scout-17b-16e-instruct provider_model_id: groq/llama-4-scout-17b-16e-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: groq/llama-4-maverick-17b-128e-instruct model_id: groq/llama-4-maverick-17b-128e-instruct
provider_id: groq provider_id: groq
@ -401,6 +411,16 @@ models:
provider_id: groq provider_id: groq
provider_model_id: groq/llama-4-maverick-17b-128e-instruct provider_model_id: groq/llama-4-maverick-17b-128e-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: sambanova/Meta-Llama-3.1-8B-Instruct model_id: sambanova/Meta-Llama-3.1-8B-Instruct
provider_id: sambanova provider_id: sambanova

View file

@ -158,6 +158,16 @@ models:
provider_id: groq provider_id: groq
provider_model_id: groq/llama-4-scout-17b-16e-instruct provider_model_id: groq/llama-4-scout-17b-16e-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: groq/llama-4-maverick-17b-128e-instruct model_id: groq/llama-4-maverick-17b-128e-instruct
provider_id: groq provider_id: groq
@ -168,6 +178,16 @@ models:
provider_id: groq provider_id: groq
provider_model_id: groq/llama-4-maverick-17b-128e-instruct provider_model_id: groq/llama-4-maverick-17b-128e-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
model_type: llm
- metadata: - metadata:
embedding_dimension: 384 embedding_dimension: 384
model_id: all-MiniLM-L6-v2 model_id: all-MiniLM-L6-v2

View file

@ -474,6 +474,16 @@ models:
provider_id: groq-openai-compat provider_id: groq-openai-compat
provider_model_id: groq/llama-4-scout-17b-16e-instruct provider_model_id: groq/llama-4-scout-17b-16e-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
provider_id: groq-openai-compat
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
provider_id: groq-openai-compat
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: groq/llama-4-maverick-17b-128e-instruct model_id: groq/llama-4-maverick-17b-128e-instruct
provider_id: groq-openai-compat provider_id: groq-openai-compat
@ -484,6 +494,16 @@ models:
provider_id: groq-openai-compat provider_id: groq-openai-compat
provider_model_id: groq/llama-4-maverick-17b-128e-instruct provider_model_id: groq/llama-4-maverick-17b-128e-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
provider_id: groq-openai-compat
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
provider_id: groq-openai-compat
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: sambanova/Meta-Llama-3.1-8B-Instruct model_id: sambanova/Meta-Llama-3.1-8B-Instruct
provider_id: sambanova-openai-compat provider_id: sambanova-openai-compat

View file

@ -115,7 +115,7 @@ def test_openai_completion_streaming(openai_client, client_with_models, text_mod
stream=True, stream=True,
max_tokens=50, max_tokens=50,
) )
streamed_content = [chunk.choices[0].text for chunk in response] streamed_content = [chunk.choices[0].text or "" for chunk in response]
content_str = "".join(streamed_content).lower().strip() content_str = "".join(streamed_content).lower().strip()
assert len(content_str) > 10 assert len(content_str) > 10

View file

@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import pytest
from llama_stack.providers.utils.scheduler import JobStatus, Scheduler
@pytest.mark.asyncio
async def test_scheduler_unknown_backend():
with pytest.raises(ValueError):
Scheduler(backend="unknown")
@pytest.mark.asyncio
async def test_scheduler_naive():
sched = Scheduler()
# make sure the scheduler starts empty
with pytest.raises(ValueError):
sched.get_job("unknown")
assert sched.get_jobs() == []
called = False
# schedule a job that will exercise the handlers
async def job_handler(on_log, on_status, on_artifact):
nonlocal called
called = True
# exercise the handlers
on_log("test log1")
on_log("test log2")
on_artifact({"type": "type1", "path": "path1"})
on_artifact({"type": "type2", "path": "path2"})
on_status(JobStatus.completed)
job_id = "test_job_id"
job_type = "test_job_type"
sched.schedule(job_type, job_id, job_handler)
# make sure the job was properly registered
with pytest.raises(ValueError):
sched.get_job("unknown")
assert sched.get_job(job_id) is not None
assert sched.get_jobs() == [sched.get_job(job_id)]
assert sched.get_jobs("unknown") == []
assert sched.get_jobs(job_type) == [sched.get_job(job_id)]
# now shut the scheduler down and make sure the job ran
await sched.shutdown()
assert called
job = sched.get_job(job_id)
assert job is not None
assert job.status == JobStatus.completed
assert job.scheduled_at is not None
assert job.started_at is not None
assert job.completed_at is not None
assert job.scheduled_at < job.started_at < job.completed_at
assert job.artifacts == [
{"type": "type1", "path": "path1"},
{"type": "type2", "path": "path2"},
]
assert [msg[1] for msg in job.logs] == ["test log1", "test log2"]
assert job.logs[0][0] < job.logs[1][0]
@pytest.mark.asyncio
async def test_scheduler_naive_handler_raises():
sched = Scheduler()
async def failing_job_handler(on_log, on_status, on_artifact):
on_status(JobStatus.running)
raise ValueError("test error")
job_id = "test_job_id1"
job_type = "test_job_type"
sched.schedule(job_type, job_id, failing_job_handler)
job = sched.get_job(job_id)
assert job is not None
# confirm the exception made the job transition to failed state, even
# though it was set to `running` before the error
for _ in range(10):
if job.status == JobStatus.failed:
break
await asyncio.sleep(0.1)
assert job.status == JobStatus.failed
# confirm that the raised error got registered in the log
assert job.logs[0][1] == "test error"
# even after a failed job, we can schedule another one
called = False
async def successful_job_handler(on_log, on_status, on_artifact):
nonlocal called
called = True
on_status(JobStatus.completed)
job_id = "test_job_id2"
sched.schedule(job_type, job_id, successful_job_handler)
await sched.shutdown()
assert called
job = sched.get_job(job_id)
assert job is not None
assert job.status == JobStatus.completed

View file

@ -0,0 +1,14 @@
base_url: http://localhost:8321/v1/openai/v1
api_key_var: FIREWORKS_API_KEY
models:
- fireworks/llama-v3p3-70b-instruct
- fireworks/llama4-scout-instruct-basic
- fireworks/llama4-maverick-instruct-basic
model_display_names:
  fireworks/llama-v3p3-70b-instruct: Llama-3.3-70B-Instruct
  fireworks/llama4-scout-instruct-basic: Llama-4-Scout-Instruct
  fireworks/llama4-maverick-instruct-basic: Llama-4-Maverick-Instruct
test_exclusions:
  fireworks/llama-v3p3-70b-instruct:
  - test_chat_non_streaming_image
  - test_chat_streaming_image

@ -0,0 +1,14 @@
base_url: http://localhost:8321/v1/openai/v1
api_key_var: GROQ_API_KEY
models:
- groq/llama-3.3-70b-versatile
- groq/llama-4-scout-17b-16e-instruct
- groq/llama-4-maverick-17b-128e-instruct
model_display_names:
  groq/llama-3.3-70b-versatile: Llama-3.3-70B-Instruct
  groq/llama-4-scout-17b-16e-instruct: Llama-4-Scout-Instruct
  groq/llama-4-maverick-17b-128e-instruct: Llama-4-Maverick-Instruct
test_exclusions:
  groq/llama-3.3-70b-versatile:
  - test_chat_non_streaming_image
  - test_chat_streaming_image

@ -2,12 +2,12 @@ base_url: https://api.groq.com/openai/v1
 api_key_var: GROQ_API_KEY
 models:
 - llama-3.3-70b-versatile
-- llama-4-scout-17b-16e-instruct
-- llama-4-maverick-17b-128e-instruct
+- meta-llama/llama-4-scout-17b-16e-instruct
+- meta-llama/llama-4-maverick-17b-128e-instruct
 model_display_names:
   llama-3.3-70b-versatile: Llama-3.3-70B-Instruct
-  llama-4-scout-17b-16e-instruct: Llama-4-Scout-Instruct
-  llama-4-maverick-17b-128e-instruct: Llama-4-Maverick-Instruct
+  meta-llama/llama-4-scout-17b-16e-instruct: Llama-4-Scout-Instruct
+  meta-llama/llama-4-maverick-17b-128e-instruct: Llama-4-Maverick-Instruct
 test_exclusions:
   llama-3.3-70b-versatile:
   - test_chat_non_streaming_image

@ -0,0 +1,9 @@
base_url: http://localhost:8321/v1/openai/v1
api_key_var: OPENAI_API_KEY
models:
- openai/gpt-4o
- openai/gpt-4o-mini
model_display_names:
  openai/gpt-4o: gpt-4o
  openai/gpt-4o-mini: gpt-4o-mini
test_exclusions: {}

@ -0,0 +1,14 @@
base_url: http://localhost:8321/v1/openai/v1
api_key_var: TOGETHER_API_KEY
models:
- together/meta-llama/Llama-3.3-70B-Instruct-Turbo
- together/meta-llama/Llama-4-Scout-17B-16E-Instruct
- together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
model_display_names:
  together/meta-llama/Llama-3.3-70B-Instruct-Turbo: Llama-3.3-70B-Instruct
  together/meta-llama/Llama-4-Scout-17B-16E-Instruct: Llama-4-Scout-Instruct
  together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8: Llama-4-Maverick-Instruct
test_exclusions:
  together/meta-llama/Llama-3.3-70B-Instruct-Turbo:
  - test_chat_non_streaming_image
  - test_chat_streaming_image
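
Each of these provider verification configs shares the same five keys (base_url, api_key_var, models, model_display_names, test_exclusions). Below is a rough sketch of how such a file could be consumed, not necessarily how the verification suite itself loads it; the filename is illustrative and PyYAML is assumed to be available:

import os

import yaml

# load one verification config (example filename)
with open("together-llama-stack.yaml") as f:
    cfg = yaml.safe_load(f)

base_url = cfg["base_url"]
# the config names the env var; the key itself comes from the environment
api_key = os.environ.get(cfg["api_key_var"], "")

for model in cfg["models"]:
    display = cfg["model_display_names"].get(model, model)
    excluded = set(cfg["test_exclusions"].get(model, []))
    print(f"{display}: skip {sorted(excluded) or 'nothing'}")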

@ -67,7 +67,17 @@ RESULTS_DIR.mkdir(exist_ok=True)
 # Maximum number of test result files to keep per provider
 MAX_RESULTS_PER_PROVIDER = 1

-PROVIDER_ORDER = ["together", "fireworks", "groq", "cerebras", "openai"]
+PROVIDER_ORDER = [
+    "together",
+    "fireworks",
+    "groq",
+    "cerebras",
+    "openai",
+    "together-llama-stack",
+    "fireworks-llama-stack",
+    "groq-llama-stack",
+    "openai-llama-stack",
+]

 VERIFICATION_CONFIG = _load_all_verification_configs()

@ -0,0 +1,146 @@
version: '2'
image_name: openai-api-verification
apis:
- inference
- telemetry
- tool_runtime
- vector_io
providers:
  inference:
  - provider_id: together
    provider_type: remote::together
    config:
      url: https://api.together.xyz/v1
      api_key: ${env.TOGETHER_API_KEY:}
  - provider_id: fireworks
    provider_type: remote::fireworks
    config:
      url: https://api.fireworks.ai/inference/v1
      api_key: ${env.FIREWORKS_API_KEY}
  - provider_id: groq
    provider_type: remote::groq
    config:
      url: https://api.groq.com
      api_key: ${env.GROQ_API_KEY}
  - provider_id: openai
    provider_type: remote::openai
    config:
      url: https://api.openai.com/v1
      api_key: ${env.OPENAI_API_KEY:}
  - provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
    config: {}
  vector_io:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/faiss_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/openai/trace_store.db}
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
    config:
      api_key: ${env.BRAVE_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: tavily-search
    provider_type: remote::tavily-search
    config:
      api_key: ${env.TAVILY_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: code-interpreter
    provider_type: inline::code-interpreter
    config: {}
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
  - provider_id: wolfram-alpha
    provider_type: remote::wolfram-alpha
    config:
      api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/registry.db
models:
- metadata: {}
  model_id: together/meta-llama/Llama-3.3-70B-Instruct-Turbo
  provider_id: together
  provider_model_id: meta-llama/Llama-3.3-70B-Instruct-Turbo
  model_type: llm
- metadata: {}
  model_id: together/meta-llama/Llama-4-Scout-17B-16E-Instruct
  provider_id: together
  provider_model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
  model_type: llm
- metadata: {}
  model_id: together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
  provider_id: together
  provider_model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
  model_type: llm
- metadata: {}
  model_id: fireworks/llama-v3p3-70b-instruct
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
  model_type: llm
- metadata: {}
  model_id: fireworks/llama4-scout-instruct-basic
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic
  model_type: llm
- metadata: {}
  model_id: fireworks/llama4-maverick-instruct-basic
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic
  model_type: llm
- metadata: {}
  model_id: groq/llama-3.3-70b-versatile
  provider_id: groq
  provider_model_id: groq/llama-3.3-70b-versatile
  model_type: llm
- metadata: {}
  model_id: groq/llama-4-scout-17b-16e-instruct
  provider_id: groq
  provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
  model_type: llm
- metadata: {}
  model_id: groq/llama-4-maverick-17b-128e-instruct
  provider_id: groq
  provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
  model_type: llm
- metadata: {}
  model_id: openai/gpt-4o
  provider_id: openai
  provider_model_id: openai/gpt-4o
  model_type: llm
- metadata: {}
  model_id: openai/gpt-4o-mini
  provider_id: openai
  provider_model_id: openai/gpt-4o-mini
  model_type: llm
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
  provider_id: tavily-search
- toolgroup_id: builtin::rag
  provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
  provider_id: code-interpreter
- toolgroup_id: builtin::wolfram_alpha
  provider_id: wolfram-alpha
server:
  port: 8321

@ -99,6 +99,9 @@ def model_mapping(provider, providers_model_mapping):
 @pytest.fixture
 def openai_client(base_url, api_key):
+    # Simplify running against a local Llama Stack
+    if "localhost" in base_url and not api_key:
+        api_key = "empty"
     return OpenAI(
         base_url=base_url,
         api_key=api_key,
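
For reference, a standalone sketch of what this fixture sets up against a locally running stack. The model id is one of the entries registered in the run.yaml above, and treating "empty" as an acceptable placeholder key is an assumption drawn from the fixture itself:

from openai import OpenAI

# point the standard OpenAI client at the local Llama Stack OpenAI-compatible endpoint
client = OpenAI(
    base_url="http://localhost:8321/v1/openai/v1",
    api_key="empty",  # placeholder, mirroring the fixture's localhost default
)

response = client.chat.completions.create(
    model="together/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    messages=[{"role": "user", "content": "Say hello in one word."}],
)
print(response.choices[0].message.content)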