Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-10 13:28:40 +00:00)

Commit 4a3d1e33f8 (parent f50ce11a3b): "test"

# What does this PR do?

Replaces the long keyword-argument signatures of the OpenAI-compatible inference endpoints (`openai_completion` and `openai_chat_completion`) with single Pydantic request models (`OpenAICompletionRequest` and `OpenaiChatCompletionRequest`), and updates the generated API specs, the inference router, the server request handling, the library client, and the in-tree providers to pass the new `params` object through.

## Test Plan

31 changed files with 727 additions and 892 deletions
@@ -11,6 +11,7 @@ from pathlib import Path
 from typing import TextIO
 from typing import Any, List, Optional, Union, get_type_hints, get_origin, get_args
 
+from pydantic import BaseModel
 from llama_stack.strong_typing.schema import object_to_json, StrictJsonType
 from llama_stack.core.resolver import api_protocol_map
 
@@ -205,6 +206,14 @@ def _validate_has_return_in_docstring(method) -> str | None:
 def _validate_has_params_in_docstring(method) -> str | None:
     source = inspect.getsource(method)
     sig = inspect.signature(method)
+
+    params_list = [p for p in sig.parameters.values() if p.name != "self"]
+    if len(params_list) == 1:
+        param = params_list[0]
+        param_type = param.annotation
+        if issubclass(param_type, BaseModel):
+            return
+
     # Only check if the method has more than one parameter
     if len(sig.parameters) > 1 and ":param" not in source:
         return "does not have a ':param' in its docstring"

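For reference, a minimal sketch (hypothetical names, not from this PR) of the two signature styles the updated check distinguishes: a method whose only non-self parameter is a Pydantic model documents its fields on the model itself, so the `:param` docstring requirement is skipped for it.

```python
# Hedged sketch with hypothetical names. The updated validator skips the
# ':param' check when the only non-self parameter is a Pydantic BaseModel.
from pydantic import BaseModel


class ExampleRequest(BaseModel):
    """Request parameters for a hypothetical endpoint.

    :param name: The name to greet.
    """

    name: str


class ExampleProtocol:
    async def greet(self, params: ExampleRequest) -> str:
        """Create a greeting.

        :returns: A greeting string.
        """  # no ':param' entries needed; the single Pydantic parameter carries them
        ...

    async def greet_legacy(self, name: str, excited: bool = False) -> str:
        """Create a greeting.

        :param name: The name to greet.
        :param excited: Whether to add an exclamation mark.
        :returns: A greeting string.
        """  # more than one plain parameter, so ':param' entries are still required
        ...
```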
docs/static/deprecated-llama-stack-spec.html (vendored, 33 lines changed)

@@ -7716,7 +7716,8 @@
          "model",
          "messages"
        ],
-       "title": "OpenaiChatCompletionRequest"
+       "title": "OpenaiChatCompletionRequest",
+       "description": "Request parameters for OpenAI-compatible chat completion endpoint."
      },
      "OpenAIChatCompletion": {
        "type": "object",
@@ -7900,7 +7901,7 @@
        ],
        "title": "OpenAICompletionWithInputMessages"
      },
-     "OpenaiCompletionRequest": {
+     "OpenAICompletionRequest": {
        "type": "object",
        "properties": {
          "model": {
@@ -8031,18 +8032,20 @@
            "type": "string",
            "description": "(Optional) The user to use."
          },
+         "suffix": {
+           "type": "string",
+           "description": "(Optional) The suffix that should be appended to the completion."
+         },
          "guided_choice": {
            "type": "array",
            "items": {
              "type": "string"
-           }
+           },
+           "description": "(Optional) vLLM-specific parameter for guided generation with a list of choices."
          },
          "prompt_logprobs": {
-           "type": "integer"
-         },
-         "suffix": {
-           "type": "string",
-           "description": "(Optional) The suffix that should be appended to the completion."
+           "type": "integer",
+           "description": "(Optional) vLLM-specific parameter for number of log probabilities to return for prompt tokens."
          }
        },
        "additionalProperties": false,
@@ -8050,6 +8053,20 @@
          "model",
          "prompt"
        ],
+       "title": "OpenAICompletionRequest",
+       "description": "Request parameters for OpenAI-compatible completion endpoint."
+     },
+     "OpenaiCompletionRequest": {
+       "type": "object",
+       "properties": {
+         "params": {
+           "$ref": "#/components/schemas/OpenAICompletionRequest"
+         }
+       },
+       "additionalProperties": false,
+       "required": [
+         "params"
+       ],
        "title": "OpenaiCompletionRequest"
      },
      "OpenAICompletion": {
        "type": "object",

docs/static/deprecated-llama-stack-spec.yaml (vendored, 33 lines changed)

@@ -5671,6 +5671,8 @@ components:
       - model
       - messages
       title: OpenaiChatCompletionRequest
+      description: >-
+        Request parameters for OpenAI-compatible chat completion endpoint.
    OpenAIChatCompletion:
      type: object
      properties:
@@ -5824,7 +5826,7 @@ components:
       - model
       - input_messages
      title: OpenAICompletionWithInputMessages
-   OpenaiCompletionRequest:
+   OpenAICompletionRequest:
      type: object
      properties:
        model:
@@ -5912,20 +5914,37 @@ components:
        user:
          type: string
          description: (Optional) The user to use.
-       guided_choice:
-         type: array
-         items:
-           type: string
-       prompt_logprobs:
-         type: integer
        suffix:
          type: string
          description: >-
            (Optional) The suffix that should be appended to the completion.
+       guided_choice:
+         type: array
+         items:
+           type: string
+         description: >-
+           (Optional) vLLM-specific parameter for guided generation with a list of
+           choices.
+       prompt_logprobs:
+         type: integer
+         description: >-
+           (Optional) vLLM-specific parameter for number of log probabilities to
+           return for prompt tokens.
      additionalProperties: false
      required:
       - model
       - prompt
+     title: OpenAICompletionRequest
+     description: >-
+       Request parameters for OpenAI-compatible completion endpoint.
+   OpenaiCompletionRequest:
+     type: object
+     properties:
+       params:
+         $ref: '#/components/schemas/OpenAICompletionRequest'
+     additionalProperties: false
+     required:
+      - params
      title: OpenaiCompletionRequest
    OpenAICompletion:
      type: object

docs/static/llama-stack-spec.html (vendored, 33 lines changed)

Identical generated-spec changes to docs/static/deprecated-llama-stack-spec.html above (renamed OpenAICompletionRequest schema with suffix/guided_choice/prompt_logprobs descriptions, added OpenaiCompletionRequest wrapper with a params $ref, and a description for OpenaiChatCompletionRequest), applied at these hunks:
@@ -5212,7 +5212,8 @@, @@ -5396,7 +5397,7 @@, @@ -5527,18 +5528,20 @@ and @@ -5546,6 +5549,20 @@.

docs/static/llama-stack-spec.yaml (vendored, 33 lines changed)

Identical generated-spec changes to docs/static/deprecated-llama-stack-spec.yaml above, applied at these hunks:
@@ -3920,6 +3920,8 @@, @@ -4073,7 +4075,7 @@ and @@ -4161,20 +4163,37 @@.

docs/static/stainless-llama-stack-spec.html (vendored, 33 lines changed)

Identical generated-spec changes to docs/static/deprecated-llama-stack-spec.html above, applied at these hunks:
@@ -7221,7 +7221,8 @@, @@ -7405,7 +7406,7 @@, @@ -7536,18 +7537,20 @@ and @@ -7555,6 +7558,20 @@.

docs/static/stainless-llama-stack-spec.yaml (vendored, 33 lines changed)

Identical generated-spec changes to docs/static/deprecated-llama-stack-spec.yaml above, applied at these hunks:
@@ -5365,6 +5365,8 @@, @@ -5518,7 +5520,7 @@ and @@ -5606,20 +5608,37 @@.

@@ -14,7 +14,7 @@ from typing import (
     runtime_checkable,
 )
 
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 from typing_extensions import TypedDict
 
 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
@@ -995,6 +995,120 @@ class ListOpenAIChatCompletionResponse(BaseModel):
     object: Literal["list"] = "list"
 
 
+@json_schema_type
+class OpenAICompletionRequest(BaseModel):
+    """Request parameters for OpenAI-compatible completion endpoint.
+
+    :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
+    :param prompt: The prompt to generate a completion for.
+    :param best_of: (Optional) The number of completions to generate.
+    :param echo: (Optional) Whether to echo the prompt.
+    :param frequency_penalty: (Optional) The penalty for repeated tokens.
+    :param logit_bias: (Optional) The logit bias to use.
+    :param logprobs: (Optional) The log probabilities to use.
+    :param max_tokens: (Optional) The maximum number of tokens to generate.
+    :param n: (Optional) The number of completions to generate.
+    :param presence_penalty: (Optional) The penalty for repeated tokens.
+    :param seed: (Optional) The seed to use.
+    :param stop: (Optional) The stop tokens to use.
+    :param stream: (Optional) Whether to stream the response.
+    :param stream_options: (Optional) The stream options to use.
+    :param temperature: (Optional) The temperature to use.
+    :param top_p: (Optional) The top p to use.
+    :param user: (Optional) The user to use.
+    :param suffix: (Optional) The suffix that should be appended to the completion.
+    :param guided_choice: (Optional) vLLM-specific parameter for guided generation with a list of choices.
+    :param prompt_logprobs: (Optional) vLLM-specific parameter for number of log probabilities to return for prompt tokens.
+    """
+
+    model_config = ConfigDict(extra="allow")
+
+    # Required parameters
+    model: str
+    prompt: str | list[str] | list[int] | list[list[int]]
+
+    # Standard OpenAI completion parameters
+    best_of: int | None = None
+    echo: bool | None = None
+    frequency_penalty: float | None = None
+    logit_bias: dict[str, float] | None = None
+    logprobs: bool | None = None
+    max_tokens: int | None = None
+    n: int | None = None
+    presence_penalty: float | None = None
+    seed: int | None = None
+    stop: str | list[str] | None = None
+    stream: bool | None = None
+    stream_options: dict[str, Any] | None = None
+    temperature: float | None = None
+    top_p: float | None = None
+    user: str | None = None
+    suffix: str | None = None
+
+    # vLLM-specific parameters (documented here but also allowed via extra fields)
+    guided_choice: list[str] | None = None
+    prompt_logprobs: int | None = None
+
+
+@json_schema_type
+class OpenaiChatCompletionRequest(BaseModel):
+    """Request parameters for OpenAI-compatible chat completion endpoint.
+
+    :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
+    :param messages: List of messages in the conversation.
+    :param frequency_penalty: (Optional) The penalty for repeated tokens.
+    :param function_call: (Optional) The function call to use.
+    :param functions: (Optional) List of functions to use.
+    :param logit_bias: (Optional) The logit bias to use.
+    :param logprobs: (Optional) The log probabilities to use.
+    :param max_completion_tokens: (Optional) The maximum number of tokens to generate.
+    :param max_tokens: (Optional) The maximum number of tokens to generate.
+    :param n: (Optional) The number of completions to generate.
+    :param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
+    :param presence_penalty: (Optional) The penalty for repeated tokens.
+    :param response_format: (Optional) The response format to use.
+    :param seed: (Optional) The seed to use.
+    :param stop: (Optional) The stop tokens to use.
+    :param stream: (Optional) Whether to stream the response.
+    :param stream_options: (Optional) The stream options to use.
+    :param temperature: (Optional) The temperature to use.
+    :param tool_choice: (Optional) The tool choice to use.
+    :param tools: (Optional) The tools to use.
+    :param top_logprobs: (Optional) The top log probabilities to use.
+    :param top_p: (Optional) The top p to use.
+    :param user: (Optional) The user to use.
+    """
+
+    model_config = ConfigDict(extra="allow")
+
+    # Required parameters
+    model: str
+    messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)]
+
+    # Standard OpenAI chat completion parameters
+    frequency_penalty: float | None = None
+    function_call: str | dict[str, Any] | None = None
+    functions: list[dict[str, Any]] | None = None
+    logit_bias: dict[str, float] | None = None
+    logprobs: bool | None = None
+    max_completion_tokens: int | None = None
+    max_tokens: int | None = None
+    n: int | None = None
+    parallel_tool_calls: bool | None = None
+    presence_penalty: float | None = None
+    response_format: OpenAIResponseFormatParam | None = None
+    seed: int | None = None
+    stop: str | list[str] | None = None
+    stream: bool | None = None
+    stream_options: dict[str, Any] | None = None
+    temperature: float | None = None
+    tool_choice: str | dict[str, Any] | None = None
+    tools: list[dict[str, Any]] | None = None
+    top_logprobs: int | None = None
+    top_p: float | None = None
+    user: str | None = None
+
+
 @runtime_checkable
 @trace_protocol
 class InferenceProvider(Protocol):
@@ -1029,52 +1143,11 @@ class InferenceProvider(Protocol):
     @webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_completion(
         self,
-        # Standard OpenAI completion parameters
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        # vLLM-specific parameters
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        # for fill-in-the-middle type completion
-        suffix: str | None = None,
+        params: OpenAICompletionRequest,
     ) -> OpenAICompletion:
         """Create completion.
 
         Generate an OpenAI-compatible completion for the given prompt using the specified model.
 
-        :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param prompt: The prompt to generate a completion for.
-        :param best_of: (Optional) The number of completions to generate.
-        :param echo: (Optional) Whether to echo the prompt.
-        :param frequency_penalty: (Optional) The penalty for repeated tokens.
-        :param logit_bias: (Optional) The logit bias to use.
-        :param logprobs: (Optional) The log probabilities to use.
-        :param max_tokens: (Optional) The maximum number of tokens to generate.
-        :param n: (Optional) The number of completions to generate.
-        :param presence_penalty: (Optional) The penalty for repeated tokens.
-        :param seed: (Optional) The seed to use.
-        :param stop: (Optional) The stop tokens to use.
-        :param stream: (Optional) Whether to stream the response.
-        :param stream_options: (Optional) The stream options to use.
-        :param temperature: (Optional) The temperature to use.
-        :param top_p: (Optional) The top p to use.
-        :param user: (Optional) The user to use.
-        :param suffix: (Optional) The suffix that should be appended to the completion.
         :returns: An OpenAICompletion.
         """
         ...
@@ -1083,57 +1156,11 @@ class InferenceProvider(Protocol):
     @webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_chat_completion(
         self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
+        params: OpenaiChatCompletionRequest,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         """Create chat completions.
 
         Generate an OpenAI-compatible chat completion for the given messages using the specified model.
 
-        :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param messages: List of messages in the conversation.
-        :param frequency_penalty: (Optional) The penalty for repeated tokens.
-        :param function_call: (Optional) The function call to use.
-        :param functions: (Optional) List of functions to use.
-        :param logit_bias: (Optional) The logit bias to use.
-        :param logprobs: (Optional) The log probabilities to use.
-        :param max_completion_tokens: (Optional) The maximum number of tokens to generate.
-        :param max_tokens: (Optional) The maximum number of tokens to generate.
-        :param n: (Optional) The number of completions to generate.
-        :param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
-        :param presence_penalty: (Optional) The penalty for repeated tokens.
-        :param response_format: (Optional) The response format to use.
-        :param seed: (Optional) The seed to use.
-        :param stop: (Optional) The stop tokens to use.
-        :param stream: (Optional) Whether to stream the response.
-        :param stream_options: (Optional) The stream options to use.
-        :param temperature: (Optional) The temperature to use.
-        :param tool_choice: (Optional) The tool choice to use.
-        :param tools: (Optional) The tools to use.
-        :param top_logprobs: (Optional) The top log probabilities to use.
-        :param top_p: (Optional) The top p to use.
-        :param user: (Optional) The user to use.
         :returns: An OpenAIChatCompletion.
         """
         ...

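A hedged caller-side sketch of the new calling convention, using only the request models and fields shown in this diff; the model id and the `inference_api` object are placeholders.

```python
# Hypothetical caller: build one request object per call instead of passing
# individual keyword arguments. `inference_api` stands in for any object
# implementing the InferenceProvider protocol defined above.
from llama_stack.apis.inference import (
    OpenAICompletionRequest,
    OpenaiChatCompletionRequest,
    OpenAIUserMessageParam,
)


async def example(inference_api) -> None:
    completion = await inference_api.openai_completion(
        OpenAICompletionRequest(
            model="llama3.2:3b",  # placeholder model id
            prompt="Say hello",
            max_tokens=32,
            guided_choice=["hello", "goodbye"],  # vLLM-specific, still a first-class field
        )
    )

    chat = await inference_api.openai_chat_completion(
        OpenaiChatCompletionRequest(
            model="llama3.2:3b",
            messages=[OpenAIUserMessageParam(role="user", content="Say hello")],
            temperature=0.7,
            stream=False,
        )
    )
```

Because both models set `extra="allow"`, parameters that Llama Stack does not model explicitly can still be supplied and reach the provider unchanged.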
@@ -383,7 +383,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
 
         body, field_names = self._handle_file_uploads(options, body)
 
-        body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
+        body = self._convert_body(matched_func, body, exclude_params=set(field_names))
 
         trace_path = webmethod.descriptive_name or route_path
         await start_trace(trace_path, {"__location__": "library_client"})
@@ -446,7 +446,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
         body |= path_params
 
-        body = self._convert_body(path, options.method, body)
+        # Prepare body for the function call (handles both Pydantic and traditional params)
+        body = self._convert_body(func, body)
 
         trace_path = webmethod.descriptive_name or route_path
         await start_trace(trace_path, {"__location__": "library_client"})
@@ -493,17 +494,27 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         )
         return await response.parse()
 
-    def _convert_body(
-        self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None
-    ) -> dict:
+    def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict:
         if not body:
             return {}
 
-        assert self.route_impls is not None  # Should be guaranteed by request() method, assertion for mypy
         exclude_params = exclude_params or set()
 
-        func, _, _, _ = find_matching_route(method, path, self.route_impls)
         sig = inspect.signature(func)
+        params_list = [p for p in sig.parameters.values() if p.name != "self"]
+
+        # Check if the method expects a single Pydantic model parameter
+        if len(params_list) == 1:
+            param = params_list[0]
+            param_type = param.annotation
+            if issubclass(param_type, BaseModel):
+                # Strip NOT_GIVENs before passing to Pydantic
+                clean_body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
+
+                # If the body has a single key matching the parameter name, unwrap it
+                if len(clean_body) == 1 and param.name in clean_body:
+                    clean_body = clean_body[param.name]
+
+                return {param.name: param_type(**clean_body)}
 
         # Strip NOT_GIVENs to use the defaults in signature
         body = {k: v for k, v in body.items() if v is not NOT_GIVEN}

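A standalone sketch of the conversion rule `_convert_body` now applies, simplified and with hypothetical names (`ExampleRequest` stands in for the real request models): when the target function takes a single Pydantic parameter, a flat OpenAI-style body and a `{"params": {...}}`-wrapped body end up as the same model instance.

```python
# Simplified, assumed behaviour of the new body conversion; not the actual
# library-client code. ExampleRequest and openai_completion are hypothetical.
import inspect

from pydantic import BaseModel


class ExampleRequest(BaseModel):
    model: str
    prompt: str


async def openai_completion(params: ExampleRequest) -> None: ...


def convert_body(func, body: dict) -> dict:
    sig = inspect.signature(func)
    params_list = [p for p in sig.parameters.values() if p.name != "self"]
    if len(params_list) == 1 and issubclass(params_list[0].annotation, BaseModel):
        param = params_list[0]
        # Accept both {"params": {...}} and a flat OpenAI-style body.
        payload = body[param.name] if set(body) == {param.name} else body
        return {param.name: param.annotation(**payload)}
    return body


flat = convert_body(openai_completion, {"model": "m", "prompt": "hi"})
wrapped = convert_body(openai_completion, {"params": {"model": "m", "prompt": "hi"}})
assert flat["params"] == wrapped["params"]
```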
@@ -8,11 +8,11 @@ import asyncio
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
 from datetime import UTC, datetime
-from typing import Annotated, Any
+from typing import Any
 
 from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
 from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
-from pydantic import Field, TypeAdapter
+from pydantic import TypeAdapter
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -31,15 +31,16 @@ from llama_stack.apis.inference import (
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
+    OpenaiChatCompletionRequest,
     OpenAIChatCompletionToolCall,
     OpenAIChatCompletionToolCallFunction,
     OpenAIChoice,
     OpenAIChoiceLogprobs,
     OpenAICompletion,
+    OpenAICompletionRequest,
     OpenAICompletionWithInputMessages,
     OpenAIEmbeddingsResponse,
     OpenAIMessageParam,
-    OpenAIResponseFormatParam,
     Order,
     StopReason,
     ToolPromptFormat,
@@ -181,61 +182,23 @@ class InferenceRouter(Inference):
 
     async def openai_completion(
         self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
+        params: OpenAICompletionRequest,
     ) -> OpenAICompletion:
         logger.debug(
-            f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
-        )
-        model_obj = await self._get_model(model, ModelType.llm)
-        params = dict(
-            model=model_obj.identifier,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-            guided_choice=guided_choice,
-            prompt_logprobs=prompt_logprobs,
-            suffix=suffix,
+            f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
         )
+        model_obj = await self._get_model(params.model, ModelType.llm)
+
+        # Update params with the resolved model identifier
+        params.model = model_obj.identifier
+
         provider = await self.routing_table.get_provider_impl(model_obj.identifier)
-        if stream:
-            return await provider.openai_completion(**params)
+        if params.stream:
+            return await provider.openai_completion(params)
         # TODO: Metrics do NOT work with openai_completion stream=True due to the fact
         # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
-        # response_stream = await provider.openai_completion(**params)
 
-        response = await provider.openai_completion(**params)
+        response = await provider.openai_completion(params)
         if self.telemetry:
             metrics = self._construct_metrics(
                 prompt_tokens=response.usage.prompt_tokens,
@@ -254,93 +217,49 @@ class InferenceRouter(Inference):
 
     async def openai_chat_completion(
         self,
-        model: str,
-        messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
+        params: OpenaiChatCompletionRequest,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         logger.debug(
-            f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
+            f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
         )
-        model_obj = await self._get_model(model, ModelType.llm)
+        model_obj = await self._get_model(params.model, ModelType.llm)
 
         # Use the OpenAI client for a bit of extra input validation without
         # exposing the OpenAI client itself as part of our API surface
-        if tool_choice:
-            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
-            if tools is None:
+        if params.tool_choice:
+            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(params.tool_choice)
+            if params.tools is None:
                 raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
-        if tools:
-            for tool in tools:
+        if params.tools:
+            for tool in params.tools:
                 TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
 
         # Some providers make tool calls even when tool_choice is "none"
         # so just clear them both out to avoid unexpected tool calls
-        if tool_choice == "none" and tools is not None:
-            tool_choice = None
-            tools = None
+        if params.tool_choice == "none" and params.tools is not None:
+            params.tool_choice = None
+            params.tools = None
+
+        # Update params with the resolved model identifier
+        params.model = model_obj.identifier
 
-        params = dict(
-            model=model_obj.identifier,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
         provider = await self.routing_table.get_provider_impl(model_obj.identifier)
-        if stream:
-            response_stream = await provider.openai_chat_completion(**params)
+        if params.stream:
+            response_stream = await provider.openai_chat_completion(params)
 
             # For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
             # We need to add metrics to each chunk and store the final completion
             return self.stream_tokens_and_compute_metrics_openai_chat(
                 response=response_stream,
                 model=model_obj,
-                messages=messages,
+                messages=params.messages,
             )
 
         response = await self._nonstream_openai_chat_completion(provider, params)
 
         # Store the response with the ID that will be returned to the client
         if self.store:
-            asyncio.create_task(self.store.store_chat_completion(response, messages))
+            asyncio.create_task(self.store.store_chat_completion(response, params.messages))
 
         if self.telemetry:
             metrics = self._construct_metrics(
@@ -396,8 +315,10 @@ class InferenceRouter(Inference):
             return await self.store.get_chat_completion(completion_id)
         raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
 
-    async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
-        response = await provider.openai_chat_completion(**params)
+    async def _nonstream_openai_chat_completion(
+        self, provider: Inference, params: OpenaiChatCompletionRequest
+    ) -> OpenAIChatCompletion:
+        response = await provider.openai_chat_completion(params)
         for choice in response.choices:
             # some providers return an empty list for no tool calls in non-streaming responses
             # but the OpenAI API returns None. So, set tool_calls to None if it's empty

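In miniature, the router-side flow after this change looks roughly like the sketch below (simplified, hedged; `routing_table` and `get_model` are placeholders for the real router collaborators). The request object is resolved once and then forwarded as-is to the provider.

```python
# Hedged miniature of the routing flow, not the actual InferenceRouter code.
async def route_openai_chat_completion(routing_table, get_model, params):
    model_obj = await get_model(params.model)   # resolve the requested model
    params.model = model_obj.identifier         # rewrite in place before forwarding
    provider = await routing_table.get_provider_impl(model_obj.identifier)
    # The same request object travels on to the provider, streaming or not.
    return await provider.openai_chat_completion(params)
```

One consequence of `extra="allow"` plus passing the object through unchanged is that vLLM-only fields such as `guided_choice` and `prompt_logprobs`, or future OpenAI parameters, reach the provider without further router changes.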
@@ -13,12 +13,13 @@ import logging  # allow-direct-logging
 import os
 import sys
 import traceback
+import types
 import warnings
 from collections.abc import Callable
 from contextlib import asynccontextmanager
 from importlib.metadata import version as parse_version
 from pathlib import Path
-from typing import Annotated, Any, get_origin
+from typing import Annotated, Any, Union, get_origin
 
 import httpx
 import rich.pretty
@@ -177,7 +178,17 @@ async def lifespan(app: StackApp):
 
 def is_streaming_request(func_name: str, request: Request, **kwargs):
     # TODO: pass the api method and punt it to the Protocol definition directly
-    return kwargs.get("stream", False)
+    # Check for stream parameter at top level (old API style)
+    if "stream" in kwargs:
+        return kwargs["stream"]
+
+    # Check for stream parameter inside Pydantic request params (new API style)
+    if "params" in kwargs:
+        params = kwargs["params"]
+        if hasattr(params, "stream"):
+            return params.stream
+
+    return False
 
 
 async def maybe_await(value):
@@ -282,21 +293,42 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
     if method == "post":
         # Annotate parameters that are in the path with Path(...) and others with Body(...),
         # but preserve existing File() and Form() annotations for multipart form data
-        new_params = (
-            [new_params[0]]
-            + [
-                (
+        def get_body_embed_value(param: inspect.Parameter) -> bool:
+            """Determine if Body should use embed=True or embed=False.
+
+            For OpenAI-compatible endpoints (param name is 'params'), use embed=False
+            so the request body is parsed directly as the model (not nested).
+            This allows OpenAI clients to send standard OpenAI format.
+            For other endpoints, use embed=True for SDK compatibility.
+            """
+            # Get the actual type, stripping Optional/Union if present
+            param_type = param.annotation
+            origin = get_origin(param_type)
+            # Check for Union types (both typing.Union and types.UnionType for | syntax)
+            if origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType):
+                # Handle Optional[T] / T | None
+                args = param_type.__args__ if hasattr(param_type, "__args__") else []
+                param_type = next((arg for arg in args if arg is not type(None)), param_type)
+
+            # Check if it's a Pydantic BaseModel and param name is 'params' (OpenAI-compatible)
+            is_basemodel = issubclass(param_type, BaseModel)
+            if is_basemodel and param.name == "params":
+                return False  # Use embed=False for OpenAI-compatible endpoints
+            return True  # Use embed=True for everything else
+
+        original_params = new_params[1:]  # Skip request parameter
+        new_params = [new_params[0]]  # Keep request parameter
+
+        for param in original_params:
+            if param.name in path_params:
+                new_params.append(
                     param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
-                    if param.name in path_params
-                    else (
-                        param  # Keep original annotation if it's already an Annotated type
-                        if get_origin(param.annotation) is Annotated
-                        else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
-                    )
                 )
-                for param in new_params[1:]
-            ]
-        )
+            elif get_origin(param.annotation) is Annotated:
+                new_params.append(param)  # Keep existing annotation
+            else:
+                embed = get_body_embed_value(param)
+                new_params.append(param.replace(annotation=Annotated[param.annotation, Body(..., embed=embed)]))
 
     route_handler.__signature__ = sig.replace(parameters=new_params)

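To illustrate what the embed flag changes on the wire (hedged; body shapes shown as Python dicts, and the non-OpenAI endpoint is a made-up example):

```python
# With embed=False (single Pydantic parameter named "params", i.e. the
# OpenAI-compatible endpoints) the POST body is parsed directly as the model:
openai_style_body = {
    "model": "llama3.2:3b",  # placeholder model id
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": False,
}

# With embed=True (all other endpoints) each argument is looked up by its
# parameter name inside the body, e.g. for a hypothetical
# `async def do_something(self, item_id: str, limit: int | None = None)`:
sdk_style_body = {
    "item_id": "abc-123",
    "limit": 10,
}
```

This is what lets stock OpenAI clients post their standard request format to `/completions` and `/chat/completions` while the rest of the API keeps the nested, SDK-friendly body shape.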
@@ -49,6 +49,7 @@ from llama_stack.apis.inference import (
     Inference,
     Message,
     OpenAIAssistantMessageParam,
+    OpenaiChatCompletionRequest,
     OpenAIDeveloperMessageParam,
     OpenAIMessageParam,
     OpenAISystemMessageParam,
@@ -582,7 +583,7 @@ class ChatAgent(ShieldRunnerMixin):
         max_tokens = getattr(sampling_params, "max_tokens", None)
 
         # Use OpenAI chat completion
-        openai_stream = await self.inference_api.openai_chat_completion(
+        params = OpenaiChatCompletionRequest(
             model=self.agent_config.model,
             messages=openai_messages,
             tools=openai_tools if openai_tools else None,
@@ -593,6 +594,7 @@ class ChatAgent(ShieldRunnerMixin):
             max_tokens=max_tokens,
             stream=True,
         )
+        openai_stream = await self.inference_api.openai_chat_completion(params)
 
         # Convert OpenAI stream back to Llama Stack format
         response_stream = convert_openai_chat_completion_stream(

@@ -41,6 +41,7 @@ from llama_stack.apis.inference import (
     Inference,
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
+    OpenaiChatCompletionRequest,
     OpenAIChatCompletionToolCall,
     OpenAIChoice,
     OpenAIMessageParam,
@@ -130,7 +131,7 @@ class StreamingResponseOrchestrator:
         # (some providers don't support non-empty response_format when tools are present)
         response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
         logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
-        completion_result = await self.inference_api.openai_chat_completion(
+        params = OpenaiChatCompletionRequest(
             model=self.ctx.model,
             messages=messages,
             tools=self.ctx.chat_tools,
@@ -138,6 +139,7 @@ class StreamingResponseOrchestrator:
             temperature=self.ctx.temperature,
             response_format=response_format,
         )
+        completion_result = await self.inference_api.openai_chat_completion(params)
 
         # Process streaming chunks and build complete response
         completion_result_data = None

@@ -22,6 +22,8 @@ from llama_stack.apis.files import Files, OpenAIFilePurpose
 from llama_stack.apis.inference import (
     Inference,
     OpenAIAssistantMessageParam,
+    OpenaiChatCompletionRequest,
+    OpenAICompletionRequest,
     OpenAIDeveloperMessageParam,
     OpenAIMessageParam,
     OpenAISystemMessageParam,
@@ -601,7 +603,8 @@ class ReferenceBatchesImpl(Batches):
         # TODO(SECURITY): review body for security issues
         if request.url == "/v1/chat/completions":
             request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
-            chat_response = await self.inference_api.openai_chat_completion(**request.body)
+            params = OpenaiChatCompletionRequest(**request.body)
+            chat_response = await self.inference_api.openai_chat_completion(params)

            # this is for mypy, we don't allow streaming so we'll get the right type
            assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
@@ -615,7 +618,8 @@ class ReferenceBatchesImpl(Batches):
                },
            }
        else:  # /v1/completions
-            completion_response = await self.inference_api.openai_completion(**request.body)
+            params = OpenAICompletionRequest(**request.body)
+            completion_response = await self.inference_api.openai_completion(params)

            # this is for mypy, we don't allow streaming so we'll get the right type
            assert hasattr(completion_response, "model_dump_json"), (
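A short sketch of why the batch path constructs the request model before calling the provider: a raw batch body is now validated up front, so malformed entries fail at construction rather than deep inside the provider. The body values below are hypothetical; only the field names come from this diff:

```python
from llama_stack.apis.inference import OpenaiChatCompletionRequest, OpenAIUserMessageParam

# Hypothetical batch entry; the provider converts message dicts to typed message
# params first, then validates the whole body by constructing the request model.
body = {
    "model": "llama-3.2-3b",
    "messages": [OpenAIUserMessageParam(role="user", content="Say hi")],
    "max_tokens": 16,
}
params = OpenaiChatCompletionRequest(**body)
print(params.model, params.max_tokens)
```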
@@ -12,7 +12,14 @@ from llama_stack.apis.agents import Agents, StepType
 from llama_stack.apis.benchmarks import Benchmark
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.inference import Inference, OpenAISystemMessageParam, OpenAIUserMessageParam, UserMessage
+from llama_stack.apis.inference import (
+    Inference,
+    OpenaiChatCompletionRequest,
+    OpenAICompletionRequest,
+    OpenAISystemMessageParam,
+    OpenAIUserMessageParam,
+    UserMessage,
+)
 from llama_stack.apis.scoring import Scoring
 from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
@@ -168,11 +175,12 @@ class MetaReferenceEvalImpl(
            sampling_params["stop"] = candidate.sampling_params.stop

                input_content = json.loads(x[ColumnName.completion_input.value])
-                response = await self.inference_api.openai_completion(
+                params = OpenAICompletionRequest(
                    model=candidate.model,
                    prompt=input_content,
                    **sampling_params,
                )
+                response = await self.inference_api.openai_completion(params)
                generations.append({ColumnName.generated_answer.value: response.choices[0].text})
            elif ColumnName.chat_completion_input.value in x:
                chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
@@ -187,11 +195,12 @@ class MetaReferenceEvalImpl(
                messages += [OpenAISystemMessageParam(**x) for x in chat_completion_input_json if x["role"] == "system"]

                messages += input_messages
-                response = await self.inference_api.openai_chat_completion(
+                params = OpenaiChatCompletionRequest(
                    model=candidate.model,
                    messages=messages,
                    **sampling_params,
                )
+                response = await self.inference_api.openai_chat_completion(params)
                generations.append({ColumnName.generated_answer.value: response.choices[0].message.content})
            else:
                raise ValueError("Invalid input row")
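The eval loop splats its `sampling_params` dict straight into the request model; a small sketch of that pattern (the values and model id below are made up):

```python
from llama_stack.apis.inference import OpenAICompletionRequest

sampling_params = {"max_tokens": 64, "temperature": 0.0}  # hypothetical values
params = OpenAICompletionRequest(
    model="candidate-model",  # hypothetical candidate id
    prompt="2 + 2 =",
    **sampling_params,
)
print(params.max_tokens, params.temperature)
```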
@@ -6,16 +6,16 @@

 import asyncio
 from collections.abc import AsyncIterator
-from typing import Any

 from llama_stack.apis.inference import (
     InferenceProvider,
+    OpenaiChatCompletionRequest,
+    OpenAICompletionRequest,
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
+    OpenAICompletion,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.log import get_logger
@@ -65,7 +65,10 @@ class MetaReferenceInferenceImpl(
         if self.config.create_distributed_process_group:
             self.generator.stop()

-    async def openai_completion(self, *args, **kwargs):
+    async def openai_completion(
+        self,
+        params: OpenAICompletionRequest,
+    ) -> OpenAICompletion:
         raise NotImplementedError("OpenAI completion not supported by meta reference provider")

     async def should_refresh_models(self) -> bool:
@@ -150,28 +153,6 @@ class MetaReferenceInferenceImpl(

     async def openai_chat_completion(
         self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
+        params: OpenaiChatCompletionRequest,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")
@@ -5,17 +5,16 @@
 # the root directory of this source tree.

 from collections.abc import AsyncIterator
-from typing import Any

 from llama_stack.apis.inference import (
     InferenceProvider,
+    OpenaiChatCompletionRequest,
+    OpenAICompletionRequest,
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
 )
 from llama_stack.apis.models import ModelType
 from llama_stack.log import get_logger
@@ -73,56 +72,12 @@ class SentenceTransformersInferenceImpl(

     async def openai_completion(
         self,
-        # Standard OpenAI completion parameters
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        # vLLM-specific parameters
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        # for fill-in-the-middle type completion
-        suffix: str | None = None,
+        params: OpenAICompletionRequest,
     ) -> OpenAICompletion:
         raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")

     async def openai_chat_completion(
         self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
+        params: OpenaiChatCompletionRequest,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")
@@ -10,7 +10,13 @@ from string import Template
 from typing import Any

 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
-from llama_stack.apis.inference import Inference, Message, UserMessage
+from llama_stack.apis.inference import (
+    Inference,
+    Message,
+    OpenaiChatCompletionRequest,
+    OpenAIUserMessageParam,
+    UserMessage,
+)
 from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
@@ -290,20 +296,21 @@ class LlamaGuardShield:
         else:
             shield_input_message = self.build_text_shield_input(messages)

-        response = await self.inference_api.openai_chat_completion(
+        params = OpenaiChatCompletionRequest(
            model=self.model,
            messages=[shield_input_message],
            stream=False,
            temperature=0.0,  # default is 1, which is too high for safety
        )
+        response = await self.inference_api.openai_chat_completion(params)
        content = response.choices[0].message.content
        content = content.strip()
        return self.get_shield_response(content)

-    def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
-        return UserMessage(content=self.build_prompt(messages))
+    def build_text_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam:
+        return OpenAIUserMessageParam(role="user", content=self.build_prompt(messages))

-    def build_vision_shield_input(self, messages: list[Message]) -> UserMessage:
+    def build_vision_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam:
        conversation = []
        most_recent_img = None

@@ -335,7 +342,7 @@ class LlamaGuardShield:
            prompt.append(most_recent_img)
        prompt.append(self.build_prompt(conversation[::-1]))

-        return UserMessage(content=prompt)
+        return OpenAIUserMessageParam(role="user", content=prompt)

    def build_prompt(self, messages: list[Message]) -> str:
        categories = self.get_safety_categories()
@@ -377,11 +384,12 @@ class LlamaGuardShield:
        # TODO: Add Image based support for OpenAI Moderations
        shield_input_message = self.build_text_shield_input(messages)

-        response = await self.inference_api.openai_chat_completion(
+        params = OpenaiChatCompletionRequest(
            model=self.model,
            messages=[shield_input_message],
            stream=False,
        )
+        response = await self.inference_api.openai_chat_completion(params)
        content = response.choices[0].message.content
        content = content.strip()
        return self.get_moderation_object(content)
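A hedged sketch of the shield request the hunks above now build: one OpenAI-style user message, non-streaming, temperature pinned to 0.0. The model id is a placeholder:

```python
from llama_stack.apis.inference import OpenaiChatCompletionRequest, OpenAIUserMessageParam


def build_shield_request(model: str, prompt: str) -> OpenaiChatCompletionRequest:
    # The shield prompt travels as a single user message; temperature stays at 0.0
    # so the safety classification is deterministic, as in the hunk above.
    message = OpenAIUserMessageParam(role="user", content=prompt)
    return OpenaiChatCompletionRequest(
        model=model,
        messages=[message],
        stream=False,
        temperature=0.0,
    )
```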
@@ -6,7 +6,7 @@
 import re
 from typing import Any

-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import Inference, OpenaiChatCompletionRequest
 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
@@ -55,7 +55,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
            generated_answer=generated_answer,
        )

-        judge_response = await self.inference_api.openai_chat_completion(
+        params = OpenaiChatCompletionRequest(
            model=fn_def.params.judge_model,
            messages=[
                {
@@ -64,6 +64,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
                }
            ],
        )
+        judge_response = await self.inference_api.openai_chat_completion(params)
        content = judge_response.choices[0].message.content
        rating_regexes = fn_def.params.judge_score_regexes
@@ -8,7 +8,7 @@
 from jinja2 import Template

 from llama_stack.apis.common.content_types import InterleavedContent
-from llama_stack.apis.inference import OpenAIUserMessageParam
+from llama_stack.apis.inference import OpenaiChatCompletionRequest, OpenAIUserMessageParam
 from llama_stack.apis.tools.rag_tool import (
     DefaultRAGQueryGeneratorConfig,
     LLMRAGQueryGeneratorConfig,
@@ -65,11 +65,12 @@ async def llm_rag_query_generator(

    model = config.model
    message = OpenAIUserMessageParam(content=rendered_content)
-    response = await inference_api.openai_chat_completion(
+    params = OpenaiChatCompletionRequest(
        model=model,
        messages=[message],
        stream=False,
    )
+    response = await inference_api.openai_chat_completion(params)

    query = response.choices[0].message.content
@@ -6,21 +6,20 @@

 import json
 from collections.abc import AsyncIterator
-from typing import Any

 from botocore.client import BaseClient

 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     Inference,
+    OpenaiChatCompletionRequest,
+    OpenAICompletionRequest,
     OpenAIEmbeddingsResponse,
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
 )
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
@@ -135,56 +134,12 @@ class BedrockInferenceAdapter(

     async def openai_completion(
         self,
-        # Standard OpenAI completion parameters
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        # vLLM-specific parameters
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        # for fill-in-the-middle type completion
-        suffix: str | None = None,
+        params: OpenAICompletionRequest,
     ) -> OpenAICompletion:
         raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")

     async def openai_chat_completion(
         self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
+        params: OpenaiChatCompletionRequest,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")
@@ -5,11 +5,14 @@
 # the root directory of this source tree.

 from collections.abc import Iterable
-from typing import Any
+from typing import TYPE_CHECKING

 from databricks.sdk import WorkspaceClient

 from llama_stack.apis.inference import OpenAICompletion

+if TYPE_CHECKING:
+    from llama_stack.apis.inference import OpenAICompletionRequest
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

@@ -43,25 +46,6 @@ class DatabricksInferenceAdapter(OpenAIMixin):

     async def openai_completion(
         self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
+        params: "OpenAICompletionRequest",
     ) -> OpenAICompletion:
         raise NotImplementedError()
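These adapters import the request type only under `TYPE_CHECKING` and quote the annotation, so the name is resolved by type checkers without a runtime import. A minimal sketch of that pattern (the class name is hypothetical):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Resolved only by type checkers; avoids importing the API module at runtime.
    from llama_stack.apis.inference import OpenAICompletionRequest


class SketchAdapter:
    async def openai_completion(self, params: "OpenAICompletionRequest"):
        raise NotImplementedError()
```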
@@ -3,9 +3,12 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any
+from typing import TYPE_CHECKING

 from llama_stack.apis.inference.inference import OpenAICompletion, OpenAIEmbeddingsResponse

+if TYPE_CHECKING:
+    from llama_stack.apis.inference import OpenAICompletionRequest
 from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@@ -34,26 +37,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):

     async def openai_completion(
         self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
+        params: "OpenAICompletionRequest",
     ) -> OpenAICompletion:
         raise NotImplementedError()
@@ -13,15 +13,14 @@ from llama_stack.apis.inference import (
     Inference,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
+    OpenaiChatCompletionRequest,
     OpenAICompletion,
+    OpenAICompletionRequest,
     OpenAIEmbeddingsResponse,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
 )
 from llama_stack.apis.models import Model
 from llama_stack.core.library_client import convert_pydantic_to_json_value
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
-from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params

 from .config import PassthroughImplConfig

@@ -80,110 +79,33 @@ class PassthroughInferenceAdapter(Inference):

     async def openai_completion(
         self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
+        params: OpenAICompletionRequest,
     ) -> OpenAICompletion:
         client = self._get_client()
-        model_obj = await self.model_store.get_model(model)
+        model_obj = await self.model_store.get_model(params.model)

-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-            guided_choice=guided_choice,
-            prompt_logprobs=prompt_logprobs,
-        )
-
-        return await client.inference.openai_completion(**params)
+        # Copy params to avoid mutating the original
+        params = params.model_copy()
+        params.model = model_obj.provider_resource_id
+
+        request_params = params.model_dump(exclude_none=True)
+
+        return await client.inference.openai_completion(**request_params)

     async def openai_chat_completion(
         self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
+        params: OpenaiChatCompletionRequest,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         client = self._get_client()
-        model_obj = await self.model_store.get_model(model)
+        model_obj = await self.model_store.get_model(params.model)

-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-
-        return await client.inference.openai_chat_completion(**params)
+        # Copy params to avoid mutating the original
+        params = params.model_copy()
+        params.model = model_obj.provider_resource_id
+
+        request_params = params.model_dump(exclude_none=True)
+
+        return await client.inference.openai_chat_completion(**request_params)

     def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]:
         json_params = {}
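A sketch of the copy-then-rewrite pattern the passthrough adapter uses, relying only on standard Pydantic v2 methods (`model_copy`, `model_dump`); the alias and provider id below are placeholders:

```python
from llama_stack.apis.inference import OpenaiChatCompletionRequest, OpenAIUserMessageParam

params = OpenaiChatCompletionRequest(
    model="my-alias",  # hypothetical alias registered in the model store
    messages=[OpenAIUserMessageParam(role="user", content="ping")],
)

# Work on a copy so the caller's request object is never mutated, then drop
# None-valued optional fields before splatting the dict into the remote client.
forwarded = params.model_copy()
forwarded.model = "provider/real-model-id"  # hypothetical provider resource id
request_params = forwarded.model_dump(exclude_none=True)

print(params.model)            # still "my-alias"
print(sorted(request_params))  # only the fields that were actually set
```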
@@ -4,11 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any
+from collections.abc import AsyncIterator

 from llama_stack.apis.inference import (
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenaiChatCompletionRequest,
 )
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

@@ -34,56 +35,13 @@ class RunpodInferenceAdapter(OpenAIMixin):

     async def openai_chat_completion(
         self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ):
+        params: OpenaiChatCompletionRequest,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         """Override to add RunPod-specific stream_options requirement."""
-        if stream and not stream_options:
-            stream_options = {"include_usage": True}
-
-        return await super().openai_chat_completion(
-            model=model,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
+        # Copy params to avoid mutating the original
+        params = params.model_copy()
+
+        if params.stream and not params.stream_options:
+            params.stream_options = {"include_usage": True}
+
+        return await super().openai_chat_completion(params)
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from collections.abc import AsyncIterator
-from typing import Any
 from urllib.parse import urljoin

 import httpx
@@ -15,8 +14,7 @@ from pydantic import ConfigDict

 from llama_stack.apis.inference import (
     OpenAIChatCompletion,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
+    OpenaiChatCompletionRequest,
     ToolChoice,
 )
 from llama_stack.log import get_logger
@@ -79,61 +77,20 @@ class VLLMInferenceAdapter(OpenAIMixin):

     async def openai_chat_completion(
         self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
+        params: "OpenaiChatCompletionRequest",
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        max_tokens = max_tokens or self.config.max_tokens
+        # Copy params to avoid mutating the original
+        params = params.model_copy()
+
+        # Apply vLLM-specific defaults
+        if params.max_tokens is None and self.config.max_tokens:
+            params.max_tokens = self.config.max_tokens

         # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
         # References:
         # * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
         # * https://github.com/vllm-project/vllm/pull/10000
-        if not tools and tool_choice is not None:
-            tool_choice = ToolChoice.none.value
+        if not params.tools and params.tool_choice is not None:
+            params.tool_choice = ToolChoice.none.value

-        return await super().openai_chat_completion(
-            model=model,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
+        return await super().openai_chat_completion(params)
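Restated for illustration only, the vLLM adapter's default-filling logic on the copied request; `config_max_tokens` stands in for the adapter's `self.config.max_tokens`:

```python
from llama_stack.apis.inference import OpenaiChatCompletionRequest, ToolChoice


def apply_vllm_defaults(
    params: OpenaiChatCompletionRequest, config_max_tokens: int | None
) -> OpenaiChatCompletionRequest:
    # Copy first so the caller's object is untouched.
    params = params.model_copy()
    if params.max_tokens is None and config_max_tokens:
        params.max_tokens = config_max_tokens
    # Older vLLM releases reject a tool_choice when no tools are supplied.
    if not params.tools and params.tool_choice is not None:
        params.tool_choice = ToolChoice.none.value
    return params
```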
@@ -7,7 +7,6 @@
 import base64
 import struct
 from collections.abc import AsyncIterator
-from typing import Any

 import litellm

@@ -17,12 +16,12 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
+    OpenaiChatCompletionRequest,
     OpenAICompletion,
+    OpenAICompletionRequest,
     OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
     ToolChoice,
 )
 from llama_stack.core.request_headers import NeedsRequestProviderData
@@ -227,116 +226,88 @@ class LiteLLMOpenAIMixin(

     async def openai_completion(
         self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
+        params: OpenAICompletionRequest,
     ) -> OpenAICompletion:
-        model_obj = await self.model_store.get_model(model)
-        params = await prepare_openai_completion_params(
+        model_obj = await self.model_store.get_model(params.model)
+
+        # Extract extra fields
+        extra_body = dict(params.__pydantic_extra__ or {})
+
+        request_params = await prepare_openai_completion_params(
             model=self.get_litellm_model_name(model_obj.provider_resource_id),
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-            guided_choice=guided_choice,
-            prompt_logprobs=prompt_logprobs,
+            prompt=params.prompt,
+            best_of=params.best_of,
+            echo=params.echo,
+            frequency_penalty=params.frequency_penalty,
+            logit_bias=params.logit_bias,
+            logprobs=params.logprobs,
+            max_tokens=params.max_tokens,
+            n=params.n,
+            presence_penalty=params.presence_penalty,
+            seed=params.seed,
+            stop=params.stop,
+            stream=params.stream,
+            stream_options=params.stream_options,
+            temperature=params.temperature,
+            top_p=params.top_p,
+            user=params.user,
+            guided_choice=params.guided_choice,
+            prompt_logprobs=params.prompt_logprobs,
+            suffix=params.suffix,
             api_key=self.get_api_key(),
             api_base=self.api_base,
+            **extra_body,
         )
-        return await litellm.atext_completion(**params)
+        return await litellm.atext_completion(**request_params)

     async def openai_chat_completion(
         self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
+        params: OpenaiChatCompletionRequest,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         # Add usage tracking for streaming when telemetry is active
         from llama_stack.providers.utils.telemetry.tracing import get_current_span

-        if stream and get_current_span() is not None:
+        stream_options = params.stream_options
+        if params.stream and get_current_span() is not None:
             if stream_options is None:
                 stream_options = {"include_usage": True}
             elif "include_usage" not in stream_options:
                 stream_options = {**stream_options, "include_usage": True}
-        model_obj = await self.model_store.get_model(model)
-        params = await prepare_openai_completion_params(
+
+        model_obj = await self.model_store.get_model(params.model)
+
+        # Extract extra fields
+        extra_body = dict(params.__pydantic_extra__ or {})
+
+        request_params = await prepare_openai_completion_params(
             model=self.get_litellm_model_name(model_obj.provider_resource_id),
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
+            messages=params.messages,
+            frequency_penalty=params.frequency_penalty,
+            function_call=params.function_call,
+            functions=params.functions,
+            logit_bias=params.logit_bias,
+            logprobs=params.logprobs,
+            max_completion_tokens=params.max_completion_tokens,
+            max_tokens=params.max_tokens,
+            n=params.n,
+            parallel_tool_calls=params.parallel_tool_calls,
+            presence_penalty=params.presence_penalty,
+            response_format=params.response_format,
+            seed=params.seed,
+            stop=params.stop,
+            stream=params.stream,
             stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
+            temperature=params.temperature,
+            tool_choice=params.tool_choice,
+            tools=params.tools,
+            top_logprobs=params.top_logprobs,
+            top_p=params.top_p,
+            user=params.user,
             api_key=self.get_api_key(),
             api_base=self.api_base,
+            **extra_body,
         )
-        return await litellm.acompletion(**params)
+        return await litellm.acompletion(**request_params)

     async def check_model_availability(self, model: str) -> bool:
         """
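The LiteLLM mixin forwards unknown fields via `__pydantic_extra__`; a sketch, assuming the request models are declared with `extra="allow"` (which is what reading `__pydantic_extra__` in the hunk above implies). The extra flag below is hypothetical:

```python
from llama_stack.apis.inference import OpenAICompletionRequest

params = OpenAICompletionRequest(
    model="some-model",         # hypothetical model id
    prompt="Hello",
    vendor_specific_flag=True,  # hypothetical extra field, not a declared parameter
)

# Extra (undeclared) fields land in __pydantic_extra__ and are forwarded as-is.
extra_body = dict(params.__pydantic_extra__ or {})
print(extra_body)  # {'vendor_specific_flag': True}
```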
@ -8,7 +8,7 @@ import base64
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import AsyncIterator, Iterable
|
from collections.abc import AsyncIterator, Iterable
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from openai import NOT_GIVEN, AsyncOpenAI
|
from openai import NOT_GIVEN, AsyncOpenAI
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
@ -22,8 +22,13 @@ from llama_stack.apis.inference import (
|
||||||
OpenAIEmbeddingsResponse,
|
OpenAIEmbeddingsResponse,
|
||||||
OpenAIEmbeddingUsage,
|
OpenAIEmbeddingUsage,
|
||||||
OpenAIMessageParam,
|
OpenAIMessageParam,
|
||||||
OpenAIResponseFormatParam,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from llama_stack.apis.inference import (
|
||||||
|
OpenaiChatCompletionRequest,
|
||||||
|
OpenAICompletionRequest,
|
||||||
|
)
|
||||||
from llama_stack.apis.models import ModelType
|
from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
@ -227,96 +232,55 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
||||||
|
|
||||||
async def openai_completion(
|
async def openai_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
params: "OpenAICompletionRequest",
|
||||||
prompt: str | list[str] | list[int] | list[list[int]],
|
|
||||||
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
     ) -> OpenAICompletion:
         """
         Direct OpenAI completion API call.
         """
-        # Handle parameters that are not supported by OpenAI API, but may be by the provider
-        # prompt_logprobs is supported by vLLM
-        # guided_choice is supported by vLLM
-        # TODO: test coverage
-        extra_body: dict[str, Any] = {}
-        if prompt_logprobs is not None and prompt_logprobs >= 0:
-            extra_body["prompt_logprobs"] = prompt_logprobs
-        if guided_choice:
-            extra_body["guided_choice"] = guided_choice
+        # Extract extra fields using Pydantic's built-in __pydantic_extra__
+        extra_body = dict(params.__pydantic_extra__ or {})
+
+        # Add vLLM-specific parameters to extra_body if they are set
+        # (these are explicitly defined in the model but still go to extra_body)
+        if params.prompt_logprobs is not None and params.prompt_logprobs >= 0:
+            extra_body["prompt_logprobs"] = params.prompt_logprobs
+        if params.guided_choice:
+            extra_body["guided_choice"] = params.guided_choice

         # TODO: fix openai_completion to return type compatible with OpenAI's API response
-        resp = await self.client.completions.create(
-            **await prepare_openai_completion_params(
-                model=await self._get_provider_model_id(model),
-                prompt=prompt,
-                best_of=best_of,
-                echo=echo,
-                frequency_penalty=frequency_penalty,
-                logit_bias=logit_bias,
-                logprobs=logprobs,
-                max_tokens=max_tokens,
-                n=n,
-                presence_penalty=presence_penalty,
-                seed=seed,
-                stop=stop,
-                stream=stream,
-                stream_options=stream_options,
-                temperature=temperature,
-                top_p=top_p,
-                user=user,
-                suffix=suffix,
-            ),
-            extra_body=extra_body,
-        )
+        completion_kwargs = await prepare_openai_completion_params(
+            model=await self._get_provider_model_id(params.model),
+            prompt=params.prompt,
+            best_of=params.best_of,
+            echo=params.echo,
+            frequency_penalty=params.frequency_penalty,
+            logit_bias=params.logit_bias,
+            logprobs=params.logprobs,
+            max_tokens=params.max_tokens,
+            n=params.n,
+            presence_penalty=params.presence_penalty,
+            seed=params.seed,
+            stop=params.stop,
+            stream=params.stream,
+            stream_options=params.stream_options,
+            temperature=params.temperature,
+            top_p=params.top_p,
+            user=params.user,
+            suffix=params.suffix,
+        )

-        return await self._maybe_overwrite_id(resp, stream)  # type: ignore[no-any-return]
+        resp = await self.client.completions.create(**completion_kwargs, extra_body=extra_body)
+
+        return await self._maybe_overwrite_id(resp, params.stream)  # type: ignore[no-any-return]

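For context, the extra-body handling above relies on Pydantic v2's `__pydantic_extra__`: when a model is declared with `extra="allow"`, unknown constructor fields are kept there instead of being rejected. Below is a minimal, self-contained sketch of that pattern; `ExampleCompletionRequest` and `custom_provider_flag` are illustrative names, not the actual llama-stack request model.

from pydantic import BaseModel, ConfigDict


class ExampleCompletionRequest(BaseModel):
    # extra="allow" keeps unknown fields in __pydantic_extra__ instead of raising
    model_config = ConfigDict(extra="allow")

    model: str
    prompt: str
    prompt_logprobs: int | None = None
    guided_choice: list[str] | None = None


req = ExampleCompletionRequest(
    model="my-model",
    prompt="Hello",
    prompt_logprobs=1,
    custom_provider_flag=True,  # unknown field -> lands in __pydantic_extra__
)

extra_body = dict(req.__pydantic_extra__ or {})  # {'custom_provider_flag': True}
if req.prompt_logprobs is not None and req.prompt_logprobs >= 0:
    extra_body["prompt_logprobs"] = req.prompt_logprobs
if req.guided_choice:
    extra_body["guided_choice"] = req.guided_choice

print(extra_body)  # {'custom_provider_flag': True, 'prompt_logprobs': 1}
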
     async def openai_chat_completion(
         self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
+        params: "OpenaiChatCompletionRequest",
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         """
         Direct OpenAI chat completion API call.
         """
+        messages = params.messages
+
         if self.download_images:

             async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam:
@@ -335,35 +299,38 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):

             messages = [await _localize_image_url(m) for m in messages]

-        params = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(model),
+        request_params = await prepare_openai_completion_params(
+            model=await self._get_provider_model_id(params.model),
             messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
+            frequency_penalty=params.frequency_penalty,
+            function_call=params.function_call,
+            functions=params.functions,
+            logit_bias=params.logit_bias,
+            logprobs=params.logprobs,
+            max_completion_tokens=params.max_completion_tokens,
+            max_tokens=params.max_tokens,
+            n=params.n,
+            parallel_tool_calls=params.parallel_tool_calls,
+            presence_penalty=params.presence_penalty,
+            response_format=params.response_format,
+            seed=params.seed,
+            stop=params.stop,
+            stream=params.stream,
+            stream_options=params.stream_options,
+            temperature=params.temperature,
+            tool_choice=params.tool_choice,
+            tools=params.tools,
+            top_logprobs=params.top_logprobs,
+            top_p=params.top_p,
+            user=params.user,
         )

-        resp = await self.client.chat.completions.create(**params)
+        # Extract any additional provider-specific parameters using Pydantic's __pydantic_extra__
+        if extra_body := dict(params.__pydantic_extra__ or {}):
+            request_params["extra_body"] = extra_body
+        resp = await self.client.chat.completions.create(**request_params)

-        return await self._maybe_overwrite_id(resp, stream)  # type: ignore[no-any-return]
+        return await self._maybe_overwrite_id(resp, params.stream)  # type: ignore[no-any-return]

     async def openai_embeddings(
         self,

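A rough caller-side sketch of the new single-argument interface follows. The `OpenaiChatCompletionRequest` import path and the single positional `params` call are taken from this diff; the `adapter` argument and the `chat_template_kwargs` value are illustrative assumptions, not a prescribed usage.

from llama_stack.apis.inference import OpenaiChatCompletionRequest


async def ask(adapter):  # `adapter` is assumed to be an already-configured provider adapter
    params = OpenaiChatCompletionRequest(
        model="my-model",
        messages=[{"role": "user", "content": "Hello"}],
        stream=False,
        # unknown fields are collected via __pydantic_extra__ and forwarded as extra_body
        chat_template_kwargs={"thinking": True},
    )
    return await adapter.openai_chat_completion(params)
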
@@ -146,14 +146,17 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
     # For streaming response, collect all chunks
     chunks = [chunk async for chunk in result]

-    mock_inference_api.openai_chat_completion.assert_called_once_with(
-        model=model,
-        messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
-        response_format=None,
-        tools=None,
-        stream=True,
-        temperature=0.1,
-    )
+    # Verify the inference API was called with the correct params
+    call_args = mock_inference_api.openai_chat_completion.call_args
+    params = call_args.args[0]  # params is passed as first positional arg
+    assert params.model == model
+    assert params.messages == [
+        OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)
+    ]
+    assert params.response_format is None
+    assert params.tools is None
+    assert params.stream is True
+    assert params.temperature == 0.1

     # Should have content part events for text streaming
     # Expected: response.created, content_part.added, output_text.delta, content_part.done, response.completed

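The updated assertions throughout these tests lean on the standard `unittest.mock` call-inspection API: the request object now arrives as the first positional argument, so it is read via `call_args.args[0]` rather than `call_args.kwargs`. A minimal standalone sketch of that mechanism (unrelated to the llama-stack fixtures):

from unittest.mock import MagicMock

mock = MagicMock()
mock({"model": "m"}, stream=True)  # one positional argument plus a keyword argument

call = mock.call_args
assert call.args[0] == {"model": "m"}   # positional arguments via .args
assert call.kwargs == {"stream": True}  # keyword arguments via .kwargs
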
@@ -228,13 +231,15 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon

     # Verify
     first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
-    assert first_call.kwargs["messages"][0].content == "What is the capital of Ireland?"
-    assert first_call.kwargs["tools"] is not None
-    assert first_call.kwargs["temperature"] == 0.1
+    first_params = first_call.args[0]
+    assert first_params.messages[0].content == "What is the capital of Ireland?"
+    assert first_params.tools is not None
+    assert first_params.temperature == 0.1

     second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
-    assert second_call.kwargs["messages"][-1].content == "Dublin"
-    assert second_call.kwargs["temperature"] == 0.1
+    second_params = second_call.args[0]
+    assert second_params.messages[-1].content == "Dublin"
+    assert second_params.temperature == 0.1

     openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search")
     openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with(

@@ -309,9 +314,10 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_

     # Verify inference API was called correctly (after iterating over result)
     first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
-    assert first_call.kwargs["messages"][0].content == input_text
-    assert first_call.kwargs["tools"] is not None
-    assert first_call.kwargs["temperature"] == 0.1
+    first_params = first_call.args[0]
+    assert first_params.messages[0].content == input_text
+    assert first_params.tools is not None
+    assert first_params.temperature == 0.1

     # Check response.created event (should have empty output)
     assert chunks[0].type == "response.created"

@@ -386,9 +392,10 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope

     # Verify inference API was called correctly (after iterating over result)
     first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
-    assert first_call.kwargs["messages"][0].content == input_text
-    assert first_call.kwargs["tools"] is not None
-    assert first_call.kwargs["temperature"] == 0.1
+    first_params = first_call.args[0]
+    assert first_params.messages[0].content == input_text
+    assert first_params.tools is not None
+    assert first_params.temperature == 0.1

     # Check response.created event (should have empty output)
     assert chunks[0].type == "response.created"

@@ -435,9 +442,10 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope

     # Verify inference API was called correctly (after iterating over result)
     first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
-    assert first_call.kwargs["messages"][0].content == input_text
-    assert first_call.kwargs["tools"] is not None
-    assert first_call.kwargs["temperature"] == 0.1
+    first_params = first_call.args[0]
+    assert first_params.messages[0].content == input_text
+    assert first_params.tools is not None
+    assert first_params.temperature == 0.1

     # Check response.created event (should have empty output)
     assert chunks[0].type == "response.created"

@@ -485,7 +493,9 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im

     # Verify the correct messages were sent to the inference API, i.e.
     # all of the responses messages were converted to the chat completion message objects
-    inference_messages = mock_inference_api.openai_chat_completion.call_args_list[0].kwargs["messages"]
+    call_args = mock_inference_api.openai_chat_completion.call_args_list[0]
+    params = call_args.args[0]
+    inference_messages = params.messages
     for i, m in enumerate(input_messages):
         if isinstance(m.content, str):
             assert inference_messages[i].content == m.content

@@ -653,7 +663,8 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
     # Verify
     mock_inference_api.openai_chat_completion.assert_called_once()
     call_args = mock_inference_api.openai_chat_completion.call_args
-    sent_messages = call_args.kwargs["messages"]
+    params = call_args.args[0]
+    sent_messages = params.messages

     # Check that instructions were prepended as a system message
     assert len(sent_messages) == 2

@@ -691,7 +702,8 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
     # Verify
     mock_inference_api.openai_chat_completion.assert_called_once()
     call_args = mock_inference_api.openai_chat_completion.call_args
-    sent_messages = call_args.kwargs["messages"]
+    params = call_args.args[0]
+    sent_messages = params.messages

     # Check that instructions were prepended as a system message
     assert len(sent_messages) == 4  # 1 system + 3 input messages

@@ -751,7 +763,8 @@ async def test_create_openai_response_with_instructions_and_previous_response(
     # Verify
     mock_inference_api.openai_chat_completion.assert_called_once()
     call_args = mock_inference_api.openai_chat_completion.call_args
-    sent_messages = call_args.kwargs["messages"]
+    params = call_args.args[0]
+    sent_messages = params.messages

     # Check that instructions were prepended as a system message
     assert len(sent_messages) == 4, sent_messages

@@ -987,8 +1000,9 @@ async def test_create_openai_response_with_text_format(

     # Verify
     first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
-    assert first_call.kwargs["messages"][0].content == input_text
-    assert first_call.kwargs["response_format"] == response_format
+    first_params = first_call.args[0]
+    assert first_params.messages[0].content == input_text
+    assert first_params.response_format == response_format


 async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):

@@ -13,6 +13,7 @@ import pytest
 from llama_stack.apis.inference import (
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
+    OpenaiChatCompletionRequest,
     OpenAIChoice,
     ToolChoice,
 )

@@ -56,13 +57,14 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
         mock_client_property.return_value = mock_client

         # No tools but auto tool choice
-        await vllm_inference_adapter.openai_chat_completion(
-            "mock-model",
-            [],
+        params = OpenaiChatCompletionRequest(
+            model="mock-model",
+            messages=[{"role": "user", "content": "test"}],
             stream=False,
             tools=None,
             tool_choice=ToolChoice.auto.value,
         )
+        await vllm_inference_adapter.openai_chat_completion(params)
         mock_client.chat.completions.create.assert_called()
         call_args = mock_client.chat.completions.create.call_args
         # Ensure tool_choice gets converted to none for older vLLM versions

@@ -171,9 +173,12 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
     )

     async def do_inference():
-        await vllm_inference_adapter.openai_chat_completion(
-            "mock-model", messages=["one fish", "two fish"], stream=False
+        params = OpenaiChatCompletionRequest(
+            model="mock-model",
+            messages=[{"role": "user", "content": "one fish two fish"}],
+            stream=False,
         )
+        await vllm_inference_adapter.openai_chat_completion(params)

     with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
         mock_client = MagicMock()

@@ -186,3 +191,48 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):

     assert mock_create_client.call_count == 4  # no cheating
     assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
+
+
+async def test_extra_body_forwarding(vllm_inference_adapter):
+    """
+    Test that extra_body parameters (e.g., chat_template_kwargs) are correctly
+    forwarded to the underlying OpenAI client.
+    """
+    mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
+    vllm_inference_adapter.model_store.get_model.return_value = mock_model
+
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(
+            return_value=OpenAIChatCompletion(
+                id="chatcmpl-abc123",
+                created=1,
+                model="mock-model",
+                choices=[
+                    OpenAIChoice(
+                        message=OpenAIAssistantMessageParam(
+                            content="test response",
+                        ),
+                        finish_reason="stop",
+                        index=0,
+                    )
+                ],
+            )
+        )
+        mock_client_property.return_value = mock_client
+
+        # Test with chat_template_kwargs for Granite thinking mode
+        params = OpenaiChatCompletionRequest(
+            model="mock-model",
+            messages=[{"role": "user", "content": "test"}],
+            stream=False,
+            chat_template_kwargs={"thinking": True},
+        )
+        await vllm_inference_adapter.openai_chat_completion(params)
+
+        # Verify that the client was called with extra_body containing chat_template_kwargs
+        mock_client.chat.completions.create.assert_called_once()
+        call_kwargs = mock_client.chat.completions.create.call_args.kwargs
+        assert "extra_body" in call_kwargs
+        assert "chat_template_kwargs" in call_kwargs["extra_body"]
+        assert call_kwargs["extra_body"]["chat_template_kwargs"] == {"thinking": True}

@@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
 import pytest
 from pydantic import BaseModel, Field

-from llama_stack.apis.inference import Model, OpenAIUserMessageParam
+from llama_stack.apis.inference import Model, OpenaiChatCompletionRequest, OpenAIUserMessageParam
 from llama_stack.apis.models import ModelType
 from llama_stack.core.request_headers import request_provider_data_context
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig

@@ -271,7 +271,8 @@ class TestOpenAIMixinImagePreprocessing:
         with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
             mock_localize.return_value = (b"fake_image_data", "jpeg")

-            await mixin.openai_chat_completion(model="test-model", messages=[message])
+            params = OpenaiChatCompletionRequest(model="test-model", messages=[message])
+            await mixin.openai_chat_completion(params)

             mock_localize.assert_called_once_with("http://example.com/image.jpg")

@@ -303,7 +304,8 @@ class TestOpenAIMixinImagePreprocessing:

         with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
             with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
-                await mixin.openai_chat_completion(model="test-model", messages=[message])
+                params = OpenaiChatCompletionRequest(model="test-model", messages=[message])
+                await mixin.openai_chat_completion(params)

                 mock_localize.assert_not_called()
