Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)
Merge branch 'main' into feat/add-dana-agent-provider-stub (commit 3f85df3da2)
62 changed files with 3463 additions and 3817 deletions
|
|
@ -963,7 +963,7 @@ paths:
|
|||
Optional filter to control which routes are returned. Can be an API level
|
||||
('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
|
||||
or 'deprecated' to show deprecated routes across all levels. If not specified,
|
||||
returns only non-deprecated v1 routes.
|
||||
returns all non-deprecated routes.
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
|
|
@ -998,39 +998,6 @@ paths:
|
|||
description: List models using the OpenAI API.
|
||||
parameters: []
|
||||
deprecated: false
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: A Model.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Model'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Models
|
||||
summary: Register model.
|
||||
description: >-
|
||||
Register model.
|
||||
|
||||
Register a model.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RegisterModelRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/models/{model_id}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -1065,36 +1032,6 @@ paths:
|
|||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
delete:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Models
|
||||
summary: Unregister model.
|
||||
description: >-
|
||||
Unregister model.
|
||||
|
||||
Unregister a model.
|
||||
parameters:
|
||||
- name: model_id
|
||||
in: path
|
||||
description: >-
|
||||
The identifier of the model to unregister.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
/v1/moderations:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -1725,32 +1662,6 @@ paths:
|
|||
description: List all scoring functions.
|
||||
parameters: []
|
||||
deprecated: false
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- ScoringFunctions
|
||||
summary: Register a scoring function.
|
||||
description: Register a scoring function.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RegisterScoringFunctionRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/scoring-functions/{scoring_fn_id}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -1782,33 +1693,6 @@ paths:
|
|||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
delete:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- ScoringFunctions
|
||||
summary: Unregister a scoring function.
|
||||
description: Unregister a scoring function.
|
||||
parameters:
|
||||
- name: scoring_fn_id
|
||||
in: path
|
||||
description: >-
|
||||
The ID of the scoring function to unregister.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
/v1/scoring/score:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -1897,36 +1781,6 @@ paths:
|
|||
description: List all shields.
|
||||
parameters: []
|
||||
deprecated: false
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: A Shield.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Shield'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Shields
|
||||
summary: Register a shield.
|
||||
description: Register a shield.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RegisterShieldRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/shields/{identifier}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -1958,33 +1812,6 @@ paths:
|
|||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
delete:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Shields
|
||||
summary: Unregister a shield.
|
||||
description: Unregister a shield.
|
||||
parameters:
|
||||
- name: identifier
|
||||
in: path
|
||||
description: >-
|
||||
The identifier of the shield to unregister.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
/v1/tool-runtime/invoke:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -2080,32 +1907,6 @@ paths:
|
|||
description: List tool groups with optional provider.
|
||||
parameters: []
|
||||
deprecated: false
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- ToolGroups
|
||||
summary: Register a tool group.
|
||||
description: Register a tool group.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RegisterToolGroupRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/toolgroups/{toolgroup_id}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -2137,32 +1938,6 @@ paths:
|
|||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
delete:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- ToolGroups
|
||||
summary: Unregister a tool group.
|
||||
description: Unregister a tool group.
|
||||
parameters:
|
||||
- name: toolgroup_id
|
||||
in: path
|
||||
description: The ID of the tool group to unregister.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
/v1/tools:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -2916,11 +2691,11 @@ paths:
|
|||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
A list of InterleavedContent representing the file contents.
|
||||
A VectorStoreFileContentResponse representing the file contents.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/VectorStoreFileContentsResponse'
|
||||
$ref: '#/components/schemas/VectorStoreFileContentResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
|
|
@ -3171,7 +2946,7 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/RegisterDatasetRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
deprecated: true
|
||||
/v1beta/datasets/{dataset_id}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -3228,7 +3003,7 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
deprecated: true
|
||||
/v1alpha/eval/benchmarks:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -3279,7 +3054,7 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/RegisterBenchmarkRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
deprecated: true
|
||||
/v1alpha/eval/benchmarks/{benchmark_id}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -3336,7 +3111,7 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
deprecated: true
|
||||
/v1alpha/eval/benchmarks/{benchmark_id}/evaluations:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -6280,46 +6055,6 @@ components:
|
|||
required:
|
||||
- data
|
||||
title: OpenAIListModelsResponse
|
||||
ModelType:
|
||||
type: string
|
||||
enum:
|
||||
- llm
|
||||
- embedding
|
||||
- rerank
|
||||
title: ModelType
|
||||
description: >-
|
||||
Enumeration of supported model types in Llama Stack.
|
||||
RegisterModelRequest:
|
||||
type: object
|
||||
properties:
|
||||
model_id:
|
||||
type: string
|
||||
description: The identifier of the model to register.
|
||||
provider_model_id:
|
||||
type: string
|
||||
description: >-
|
||||
The identifier of the model in the provider.
|
||||
provider_id:
|
||||
type: string
|
||||
description: The identifier of the provider.
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: Any additional metadata for this model.
|
||||
model_type:
|
||||
$ref: '#/components/schemas/ModelType'
|
||||
description: The type of model to register.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- model_id
|
||||
title: RegisterModelRequest
|
||||
Model:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -6377,6 +6112,15 @@ components:
|
|||
title: Model
|
||||
description: >-
|
||||
A model resource representing an AI model registered in Llama Stack.
|
||||
ModelType:
|
||||
type: string
|
||||
enum:
|
||||
- llm
|
||||
- embedding
|
||||
- rerank
|
||||
title: ModelType
|
||||
description: >-
|
||||
Enumeration of supported model types in Llama Stack.
|
||||
RunModerationRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -6882,6 +6626,11 @@ components:
|
|||
type: string
|
||||
description: >-
|
||||
(Optional) System message inserted into the model's context
|
||||
max_tool_calls:
|
||||
type: integer
|
||||
description: >-
|
||||
(Optional) Max number of total calls to built-in tools that can be processed
|
||||
in a response
|
||||
input:
|
||||
type: array
|
||||
items:
|
||||
|
|
@ -7240,6 +6989,11 @@ components:
|
|||
(Optional) Additional fields to include in the response.
|
||||
max_infer_iters:
|
||||
type: integer
|
||||
max_tool_calls:
|
||||
type: integer
|
||||
description: >-
|
||||
(Optional) Max number of total calls to built-in tools that can be processed
|
||||
in a response.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- input
|
||||
|
|
@ -7321,6 +7075,11 @@ components:
|
|||
type: string
|
||||
description: >-
|
||||
(Optional) System message inserted into the model's context
|
||||
max_tool_calls:
|
||||
type: integer
|
||||
description: >-
|
||||
(Optional) Max number of total calls to built-in tools that can be processed
|
||||
in a response
|
||||
additionalProperties: false
|
||||
required:
|
||||
- created_at
|
||||
|
|
@ -9115,61 +8874,6 @@ components:
|
|||
required:
|
||||
- data
|
||||
title: ListScoringFunctionsResponse
|
||||
ParamType:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/StringType'
|
||||
- $ref: '#/components/schemas/NumberType'
|
||||
- $ref: '#/components/schemas/BooleanType'
|
||||
- $ref: '#/components/schemas/ArrayType'
|
||||
- $ref: '#/components/schemas/ObjectType'
|
||||
- $ref: '#/components/schemas/JsonType'
|
||||
- $ref: '#/components/schemas/UnionType'
|
||||
- $ref: '#/components/schemas/ChatCompletionInputType'
|
||||
- $ref: '#/components/schemas/CompletionInputType'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
string: '#/components/schemas/StringType'
|
||||
number: '#/components/schemas/NumberType'
|
||||
boolean: '#/components/schemas/BooleanType'
|
||||
array: '#/components/schemas/ArrayType'
|
||||
object: '#/components/schemas/ObjectType'
|
||||
json: '#/components/schemas/JsonType'
|
||||
union: '#/components/schemas/UnionType'
|
||||
chat_completion_input: '#/components/schemas/ChatCompletionInputType'
|
||||
completion_input: '#/components/schemas/CompletionInputType'
|
||||
RegisterScoringFunctionRequest:
|
||||
type: object
|
||||
properties:
|
||||
scoring_fn_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the scoring function to register.
|
||||
description:
|
||||
type: string
|
||||
description: The description of the scoring function.
|
||||
return_type:
|
||||
$ref: '#/components/schemas/ParamType'
|
||||
description: The return type of the scoring function.
|
||||
provider_scoring_fn_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider scoring function to use for the scoring function.
|
||||
provider_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider to use for the scoring function.
|
||||
params:
|
||||
$ref: '#/components/schemas/ScoringFnParams'
|
||||
description: >-
|
||||
The parameters for the scoring function for benchmark eval, these can
|
||||
be overridden for app eval.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- scoring_fn_id
|
||||
- description
|
||||
- return_type
|
||||
title: RegisterScoringFunctionRequest
|
||||
ScoreRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -9345,35 +9049,6 @@ components:
|
|||
required:
|
||||
- data
|
||||
title: ListShieldsResponse
|
||||
RegisterShieldRequest:
|
||||
type: object
|
||||
properties:
|
||||
shield_id:
|
||||
type: string
|
||||
description: >-
|
||||
The identifier of the shield to register.
|
||||
provider_shield_id:
|
||||
type: string
|
||||
description: >-
|
||||
The identifier of the shield in the provider.
|
||||
provider_id:
|
||||
type: string
|
||||
description: The identifier of the provider.
|
||||
params:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: The parameters of the shield.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- shield_id
|
||||
title: RegisterShieldRequest
|
||||
InvokeToolRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -9634,37 +9309,6 @@ components:
|
|||
title: ListToolGroupsResponse
|
||||
description: >-
|
||||
Response containing a list of tool groups.
|
||||
RegisterToolGroupRequest:
|
||||
type: object
|
||||
properties:
|
||||
toolgroup_id:
|
||||
type: string
|
||||
description: The ID of the tool group to register.
|
||||
provider_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider to use for the tool group.
|
||||
mcp_endpoint:
|
||||
$ref: '#/components/schemas/URL'
|
||||
description: >-
|
||||
The MCP endpoint to use for the tool group.
|
||||
args:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
A dictionary of arguments to pass to the tool group.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- toolgroup_id
|
||||
- provider_id
|
||||
title: RegisterToolGroupRequest
|
||||
Chunk:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -10465,41 +10109,35 @@ components:
|
|||
title: VectorStoreContent
|
||||
description: >-
|
||||
Content item from a vector store file or search result.
|
||||
VectorStoreFileContentsResponse:
|
||||
VectorStoreFileContentResponse:
|
||||
type: object
|
||||
properties:
|
||||
file_id:
|
||||
object:
|
||||
type: string
|
||||
description: Unique identifier for the file
|
||||
filename:
|
||||
type: string
|
||||
description: Name of the file
|
||||
attributes:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
const: vector_store.file_content.page
|
||||
default: vector_store.file_content.page
|
||||
description: >-
|
||||
Key-value attributes associated with the file
|
||||
content:
|
||||
The object type, which is always `vector_store.file_content.page`
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/VectorStoreContent'
|
||||
description: List of content items from the file
|
||||
description: Parsed content of the file
|
||||
has_more:
|
||||
type: boolean
|
||||
description: >-
|
||||
Indicates if there are more content pages to fetch
|
||||
next_page:
|
||||
type: string
|
||||
description: The token for the next page, if any
|
||||
additionalProperties: false
|
||||
required:
|
||||
- file_id
|
||||
- filename
|
||||
- attributes
|
||||
- content
|
||||
title: VectorStoreFileContentsResponse
|
||||
- object
|
||||
- data
|
||||
- has_more
|
||||
title: VectorStoreFileContentResponse
|
||||
description: >-
|
||||
Response from retrieving the contents of a vector store file.
|
||||
Represents the parsed content of a vector store file.
|
||||
OpenaiSearchVectorStoreRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -10816,68 +10454,6 @@ components:
|
|||
- data
|
||||
title: ListDatasetsResponse
|
||||
description: Response from listing datasets.
|
||||
DataSource:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/URIDataSource'
|
||||
- $ref: '#/components/schemas/RowsDataSource'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
uri: '#/components/schemas/URIDataSource'
|
||||
rows: '#/components/schemas/RowsDataSource'
|
||||
RegisterDatasetRequest:
|
||||
type: object
|
||||
properties:
|
||||
purpose:
|
||||
type: string
|
||||
enum:
|
||||
- post-training/messages
|
||||
- eval/question-answer
|
||||
- eval/messages-answer
|
||||
description: >-
|
||||
The purpose of the dataset. One of: - "post-training/messages": The dataset
|
||||
contains a messages column with list of messages for post-training. {
|
||||
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
|
||||
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
|
||||
contains a question column and an answer column for evaluation. { "question":
|
||||
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
|
||||
The dataset contains a messages column with list of messages and an answer
|
||||
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
|
||||
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
|
||||
Doe. How can I help you today?"}, {"role": "user", "content": "What's
|
||||
my name?"}, ], "answer": "John Doe" }
|
||||
source:
|
||||
$ref: '#/components/schemas/DataSource'
|
||||
description: >-
|
||||
The data source of the dataset. Ensure that the data source schema is
|
||||
compatible with the purpose of the dataset. Examples: - { "type": "uri",
|
||||
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
|
||||
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
|
||||
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
|
||||
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
|
||||
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
|
||||
} ] }
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
The metadata for the dataset. - E.g. {"description": "My dataset"}.
|
||||
dataset_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the dataset. If not provided, an ID will be generated.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- purpose
|
||||
- source
|
||||
title: RegisterDatasetRequest
|
||||
Benchmark:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -10945,47 +10521,6 @@ components:
|
|||
required:
|
||||
- data
|
||||
title: ListBenchmarksResponse
|
||||
RegisterBenchmarkRequest:
|
||||
type: object
|
||||
properties:
|
||||
benchmark_id:
|
||||
type: string
|
||||
description: The ID of the benchmark to register.
|
||||
dataset_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the dataset to use for the benchmark.
|
||||
scoring_functions:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: >-
|
||||
The scoring functions to use for the benchmark.
|
||||
provider_benchmark_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider benchmark to use for the benchmark.
|
||||
provider_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider to use for the benchmark.
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: The metadata to use for the benchmark.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- benchmark_id
|
||||
- dataset_id
|
||||
- scoring_functions
|
||||
title: RegisterBenchmarkRequest
|
||||
BenchmarkConfig:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -11847,6 +11382,109 @@ components:
|
|||
- hyperparam_search_config
|
||||
- logger_config
|
||||
title: SupervisedFineTuneRequest
|
||||
DataSource:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/URIDataSource'
|
||||
- $ref: '#/components/schemas/RowsDataSource'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
uri: '#/components/schemas/URIDataSource'
|
||||
rows: '#/components/schemas/RowsDataSource'
|
||||
RegisterDatasetRequest:
|
||||
type: object
|
||||
properties:
|
||||
purpose:
|
||||
type: string
|
||||
enum:
|
||||
- post-training/messages
|
||||
- eval/question-answer
|
||||
- eval/messages-answer
|
||||
description: >-
|
||||
The purpose of the dataset. One of: - "post-training/messages": The dataset
|
||||
contains a messages column with list of messages for post-training. {
|
||||
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
|
||||
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
|
||||
contains a question column and an answer column for evaluation. { "question":
|
||||
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
|
||||
The dataset contains a messages column with list of messages and an answer
|
||||
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
|
||||
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
|
||||
Doe. How can I help you today?"}, {"role": "user", "content": "What's
|
||||
my name?"}, ], "answer": "John Doe" }
|
||||
source:
|
||||
$ref: '#/components/schemas/DataSource'
|
||||
description: >-
|
||||
The data source of the dataset. Ensure that the data source schema is
|
||||
compatible with the purpose of the dataset. Examples: - { "type": "uri",
|
||||
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
|
||||
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
|
||||
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
|
||||
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
|
||||
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
|
||||
} ] }
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
The metadata for the dataset. - E.g. {"description": "My dataset"}.
|
||||
dataset_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the dataset. If not provided, an ID will be generated.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- purpose
|
||||
- source
|
||||
title: RegisterDatasetRequest
|
||||
RegisterBenchmarkRequest:
|
||||
type: object
|
||||
properties:
|
||||
benchmark_id:
|
||||
type: string
|
||||
description: The ID of the benchmark to register.
|
||||
dataset_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the dataset to use for the benchmark.
|
||||
scoring_functions:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: >-
|
||||
The scoring functions to use for the benchmark.
|
||||
provider_benchmark_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider benchmark to use for the benchmark.
|
||||
provider_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider to use for the benchmark.
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: The metadata to use for the benchmark.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- benchmark_id
|
||||
- dataset_id
|
||||
- scoring_functions
|
||||
title: RegisterBenchmarkRequest
|
||||
responses:
|
||||
BadRequest400:
|
||||
description: The request was invalid or malformed
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import TabItem from '@theme/TabItem';
|
|||
|
||||
# Kubernetes Deployment Guide
|
||||
|
||||
Deploy Llama Stack and vLLM servers in a Kubernetes cluster instead of running them locally. This guide covers both local development with Kind and production deployment on AWS EKS.
|
||||
Deploy Llama Stack and vLLM servers in a Kubernetes cluster instead of running them locally. This guide covers deployment using the Kubernetes operator to manage the Llama Stack server with Kind. The vLLM inference server is deployed manually.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
|
|
@ -110,115 +110,176 @@ spec:
|
|||
EOF
|
||||
```
|
||||
|
||||
### Step 3: Configure Llama Stack
|
||||
### Step 3: Install Kubernetes Operator
|
||||
|
||||
Update your run configuration:
|
||||
|
||||
```yaml
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: vllm
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
url: http://vllm-server.default.svc.cluster.local:8000/v1
|
||||
max_tokens: 4096
|
||||
api_token: fake
|
||||
```
|
||||
|
||||
Build container image:
|
||||
Install the Llama Stack Kubernetes operator to manage Llama Stack deployments:
|
||||
|
||||
```bash
|
||||
tmp_dir=$(mktemp -d) && cat >$tmp_dir/Containerfile.llama-stack-run-k8s <<EOF
|
||||
FROM distribution-myenv:dev
|
||||
RUN apt-get update && apt-get install -y git
|
||||
RUN git clone https://github.com/meta-llama/llama-stack.git /app/llama-stack-source
|
||||
ADD ./vllm-llama-stack-run-k8s.yaml /app/config.yaml
|
||||
EOF
|
||||
podman build -f $tmp_dir/Containerfile.llama-stack-run-k8s -t llama-stack-run-k8s $tmp_dir
|
||||
# Install from the latest main branch
|
||||
kubectl apply -f https://raw.githubusercontent.com/llamastack/llama-stack-k8s-operator/main/release/operator.yaml
|
||||
|
||||
# Or install a specific version (e.g., v0.4.0)
|
||||
# kubectl apply -f https://raw.githubusercontent.com/llamastack/llama-stack-k8s-operator/v0.4.0/release/operator.yaml
|
||||
```
|
||||
|
||||
### Step 4: Deploy Llama Stack Server
|
||||
Verify the operator is running:
|
||||
|
||||
```bash
|
||||
kubectl get pods -n llama-stack-operator-system
|
||||
```
|
||||
|
||||
For more information about the operator, see the [llama-stack-k8s-operator repository](https://github.com/llamastack/llama-stack-k8s-operator).
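If the operator reports as running but later steps stall, it can help to confirm that the operator's custom resource types were registered with the API server. A minimal check, assuming the `llamastack.io` API group used in the examples below:

```bash
# Confirm the LlamaStackDistribution resource type is registered
# (API group taken from the CR examples in this guide)
kubectl api-resources | grep -i llamastack
```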
|
||||
|
||||
### Step 4: Deploy Llama Stack Server using Operator
|
||||
|
||||
Create a `LlamaStackDistribution` custom resource to deploy the Llama Stack server. The operator will automatically create the necessary Deployment, Service, and other resources.
|
||||
You can optionally override the default `run.yaml` using `spec.server.userConfig` with a ConfigMap (see [userConfig spec](https://github.com/llamastack/llama-stack-k8s-operator/blob/main/docs/api-overview.md#userconfigspec)).
|
||||
|
||||
```yaml
|
||||
cat <<EOF | kubectl apply -f -
|
||||
apiVersion: v1
|
||||
kind: PersistentVolumeClaim
|
||||
apiVersion: llamastack.io/v1alpha1
|
||||
kind: LlamaStackDistribution
|
||||
metadata:
|
||||
name: llama-pvc
|
||||
spec:
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
resources:
|
||||
requests:
|
||||
storage: 1Gi
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: llama-stack-server
|
||||
name: llamastack-vllm
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app.kubernetes.io/name: llama-stack
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app.kubernetes.io/name: llama-stack
|
||||
spec:
|
||||
containers:
|
||||
- name: llama-stack
|
||||
image: localhost/llama-stack-run-k8s:latest
|
||||
imagePullPolicy: IfNotPresent
|
||||
command: ["llama", "stack", "run", "/app/config.yaml"]
|
||||
ports:
|
||||
- containerPort: 5000
|
||||
volumeMounts:
|
||||
- name: llama-storage
|
||||
mountPath: /root/.llama
|
||||
volumes:
|
||||
- name: llama-storage
|
||||
persistentVolumeClaim:
|
||||
claimName: llama-pvc
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: llama-stack-service
|
||||
spec:
|
||||
selector:
|
||||
app.kubernetes.io/name: llama-stack
|
||||
ports:
|
||||
- protocol: TCP
|
||||
port: 5000
|
||||
targetPort: 5000
|
||||
type: ClusterIP
|
||||
server:
|
||||
distribution:
|
||||
name: starter
|
||||
containerSpec:
|
||||
port: 8321
|
||||
env:
|
||||
- name: VLLM_URL
|
||||
value: "http://vllm-server.default.svc.cluster.local:8000/v1"
|
||||
- name: VLLM_MAX_TOKENS
|
||||
value: "4096"
|
||||
- name: VLLM_API_TOKEN
|
||||
value: "fake"
|
||||
# Optional: override run.yaml from a ConfigMap using userConfig
|
||||
userConfig:
|
||||
configMap:
|
||||
name: llama-stack-config
|
||||
storage:
|
||||
size: "20Gi"
|
||||
mountPath: "/home/lls/.lls"
|
||||
EOF
|
||||
```
|
||||
|
||||
**Configuration Options:**
|
||||
|
||||
- `replicas`: Number of Llama Stack server instances to run
|
||||
- `server.distribution.name`: The distribution to use (e.g., `starter` for the starter distribution). See the [list of supported distributions](https://github.com/llamastack/llama-stack-k8s-operator/blob/main/distributions.json) in the operator repository.
|
||||
- `server.distribution.image`: (Optional) Custom container image for non-supported distributions. Use this field when deploying a distribution that is not in the supported list. If specified, this takes precedence over `name`.
|
||||
- `server.containerSpec.port`: Port on which the Llama Stack server listens (default: 8321)
|
||||
- `server.containerSpec.env`: Environment variables to configure providers
|
||||
- `server.userConfig`: (Optional) Override the default `run.yaml` using a ConfigMap. See [userConfig spec](https://github.com/llamastack/llama-stack-k8s-operator/blob/main/docs/api-overview.md#userconfigspec).
|
||||
- `server.storage.size`: Size of the persistent volume for model and data storage
|
||||
- `server.storage.mountPath`: Where to mount the storage in the container
|
||||
|
||||
**Note:** For a complete list of supported distributions, see [distributions.json](https://github.com/llamastack/llama-stack-k8s-operator/blob/main/distributions.json) in the operator repository. To use a custom or non-supported distribution, set the `server.distribution.image` field with your container image instead of `server.distribution.name`.
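If you use the `userConfig` override, the referenced ConfigMap must exist before the `LlamaStackDistribution` is created. A minimal sketch for producing it from a local `run.yaml`; the data key name is an assumption, so check the userConfig spec linked above for the exact contract:

```bash
# Create the ConfigMap referenced by server.userConfig from a local run.yaml
# (the `run.yaml` data key is an assumption; verify against the operator's userConfig spec)
kubectl create configmap llama-stack-config --from-file=run.yaml=./run.yaml
```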
|
||||
|
||||
The operator automatically creates:
|
||||
- A Deployment for the Llama Stack server
|
||||
- A Service to access the server
|
||||
- A PersistentVolumeClaim for storage
|
||||
- All necessary RBAC resources
|
||||
|
||||
|
||||
Check the status of your deployment:
|
||||
|
||||
```bash
|
||||
kubectl get llamastackdistribution
|
||||
kubectl describe llamastackdistribution llamastack-vllm
|
||||
```
|
||||
|
||||
### Step 5: Test Deployment
|
||||
|
||||
Wait for the Llama Stack server pod to be ready:
|
||||
|
||||
```bash
|
||||
# Port forward and test
|
||||
kubectl port-forward service/llama-stack-service 5000:5000
|
||||
llama-stack-client --endpoint http://localhost:5000 inference chat-completion --message "hello, what model are you?"
|
||||
# Check the status of the LlamaStackDistribution
|
||||
kubectl get llamastackdistribution llamastack-vllm
|
||||
|
||||
# Check the pods created by the operator
|
||||
kubectl get pods -l app.kubernetes.io/name=llama-stack
|
||||
|
||||
# Wait for the pod to be ready
|
||||
kubectl wait --for=condition=ready pod -l app.kubernetes.io/name=llama-stack --timeout=300s
|
||||
```
|
||||
|
||||
Get the service name created by the operator (it typically follows the pattern `<llamastackdistribution-name>-service`):
|
||||
|
||||
```bash
|
||||
# List services to find the service name
|
||||
kubectl get services | grep llamastack
|
||||
|
||||
# Port forward and test (replace SERVICE_NAME with the actual service name)
|
||||
kubectl port-forward service/llamastack-vllm-service 8321:8321
|
||||
```
|
||||
|
||||
In another terminal, test the deployment:
|
||||
|
||||
```bash
|
||||
llama-stack-client --endpoint http://localhost:8321 inference chat-completion --message "hello, what model are you?"
|
||||
```
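As a quick sanity check without the client CLI, you can also query the OpenAI-compatible models endpoint through the same port-forward (the `/v1/models` route appears in the API spec updated in this change):

```bash
# List registered models through the port-forwarded Llama Stack service
curl -s http://localhost:8321/v1/models
```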
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**Check pod status:**
|
||||
### vLLM Server Issues
|
||||
|
||||
**Check vLLM pod status:**
|
||||
```bash
|
||||
kubectl get pods -l app.kubernetes.io/name=vllm
|
||||
kubectl logs -l app.kubernetes.io/name=vllm
|
||||
```
|
||||
|
||||
**Test service connectivity:**
|
||||
**Test vLLM service connectivity:**
|
||||
```bash
|
||||
kubectl run -it --rm debug --image=curlimages/curl --restart=Never -- curl http://vllm-server:8000/v1/models
|
||||
```
|
||||
|
||||
### Llama Stack Server Issues
|
||||
|
||||
**Check LlamaStackDistribution status:**
|
||||
```bash
|
||||
# Get detailed status
|
||||
kubectl describe llamastackdistribution llamastack-vllm
|
||||
|
||||
# Check for events
|
||||
kubectl get events --sort-by='.lastTimestamp' | grep llamastack-vllm
|
||||
```
|
||||
|
||||
**Check operator-managed pods:**
|
||||
```bash
|
||||
# List all pods managed by the operator
|
||||
kubectl get pods -l app.kubernetes.io/name=llama-stack
|
||||
|
||||
# Check pod logs (replace POD_NAME with actual pod name)
|
||||
kubectl logs -l app.kubernetes.io/name=llama-stack
|
||||
```
|
||||
|
||||
**Check operator status:**
|
||||
```bash
|
||||
# Verify the operator is running
|
||||
kubectl get pods -n llama-stack-operator-system
|
||||
|
||||
# Check operator logs if issues persist
|
||||
kubectl logs -n llama-stack-operator-system -l control-plane=controller-manager
|
||||
```
|
||||
|
||||
**Verify service connectivity:**
|
||||
```bash
|
||||
# Get the service endpoint
|
||||
kubectl get svc llamastack-vllm-service
|
||||
|
||||
# Test connectivity from within the cluster
|
||||
kubectl run -it --rm debug --image=curlimages/curl --restart=Never -- curl http://llamastack-vllm-service:8321/health
|
||||
```
|
||||
|
||||
## Related Resources
|
||||
|
||||
- **[Deployment Overview](/docs/deploying/)** - Overview of deployment options
|
||||
- **[Distributions](/docs/distributions)** - Understanding Llama Stack distributions
|
||||
- **[Configuration](/docs/distributions/configuration)** - Detailed configuration options
|
||||
- **[LlamaStack Operator](https://github.com/llamastack/llama-stack-k8s-operator)** - Overview of llama-stack kubernetes operator
|
||||
- **[LlamaStackDistribution](https://github.com/llamastack/llama-stack-k8s-operator/blob/main/docs/api-overview.md)** - API Spec of the llama-stack operator Custom Resource.
|
||||
|
|
|
|||
143 docs/docs/distributions/remote_hosted_distro/oci.md (new file)
|
|
@ -0,0 +1,143 @@
|
|||
---
|
||||
orphan: true
|
||||
---
|
||||
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
|
||||
# OCI Distribution
|
||||
|
||||
The `llamastack/distribution-oci` distribution consists of the following provider configurations.
|
||||
|
||||
| API | Provider(s) |
|
||||
|-----|-------------|
|
||||
| agents | `inline::meta-reference` |
|
||||
| datasetio | `remote::huggingface`, `inline::localfs` |
|
||||
| eval | `inline::meta-reference` |
|
||||
| files | `inline::localfs` |
|
||||
| inference | `remote::oci` |
|
||||
| safety | `inline::llama-guard` |
|
||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||
|
||||
|
||||
### Environment Variables
|
||||
|
||||
The following environment variables can be configured:
|
||||
|
||||
- `OCI_AUTH_TYPE`: OCI authentication type (instance_principal or config_file) (default: `instance_principal`)
|
||||
- `OCI_REGION`: OCI region (e.g., us-ashburn-1, us-chicago-1, us-phoenix-1, eu-frankfurt-1) (default: ``)
|
||||
- `OCI_COMPARTMENT_OCID`: OCI compartment ID for the Generative AI service (default: ``)
|
||||
- `OCI_CONFIG_FILE_PATH`: OCI config file path (required if OCI_AUTH_TYPE is config_file) (default: `~/.oci/config`)
|
||||
- `OCI_CLI_PROFILE`: OCI CLI profile name to use from config file (default: `DEFAULT`)
|
||||
|
||||
|
||||
## Prerequisites
|
||||
### Oracle Cloud Infrastructure Setup
|
||||
|
||||
Before using the OCI Generative AI distribution, ensure you have:
|
||||
|
||||
1. **Oracle Cloud Infrastructure Account**: Sign up at [Oracle Cloud Infrastructure](https://cloud.oracle.com/)
|
||||
2. **Generative AI Service Access**: Enable the Generative AI service in your OCI tenancy
|
||||
3. **Compartment**: Create or identify a compartment where you'll deploy Generative AI models
|
||||
4. **Authentication**: Configure authentication using either:
|
||||
- **Instance Principal** (recommended for cloud-hosted deployments)
|
||||
- **API Key** (for on-premises or development environments)
|
||||
|
||||
### Authentication Methods
|
||||
|
||||
#### Instance Principal Authentication (Recommended)
|
||||
Instance Principal authentication allows OCI resources to authenticate using the identity of the compute instance they're running on. This is the most secure method for production deployments.
|
||||
|
||||
Requirements:
|
||||
- Instance must be running in an Oracle Cloud Infrastructure compartment
|
||||
- Instance must have appropriate IAM policies to access Generative AI services
|
||||
|
||||
#### API Key Authentication
|
||||
For development or on-premises deployments, follow [this doc](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/apisigningkey.htm) to learn how to create your API signing key for your config file.
|
||||
|
||||
### Required IAM Policies
|
||||
|
||||
Ensure your OCI user or instance has the following policy statements:
|
||||
|
||||
```
|
||||
Allow group <group_name> to use generative-ai-inference-endpoints in compartment <compartment_name>
|
||||
Allow group <group_name> to manage generative-ai-inference-endpoints in compartment <compartment_name>
|
||||
```
|
||||
|
||||
## Supported Services
|
||||
|
||||
### Inference: OCI Generative AI
|
||||
Oracle Cloud Infrastructure Generative AI provides access to high-performance AI models through OCI's Platform-as-a-Service offering. The service supports:
|
||||
|
||||
- **Chat Completions**: Conversational AI with context awareness
|
||||
- **Text Generation**: Complete prompts and generate text content
|
||||
|
||||
#### Available Models
|
||||
The OCI Generative AI service provides access to models from Meta, Cohere, OpenAI, Grok, and other providers.
|
||||
|
||||
### Safety: Llama Guard
|
||||
For content safety and moderation, this distribution uses Meta's Llama Guard model through the OCI Generative AI service to provide:
|
||||
- Content filtering and moderation
|
||||
- Policy compliance checking
|
||||
- Harmful content detection
|
||||
|
||||
### Vector Storage: Multiple Options
|
||||
The distribution supports several vector storage providers:
|
||||
- **FAISS**: Local in-memory vector search
|
||||
- **ChromaDB**: Distributed vector database
|
||||
- **PGVector**: PostgreSQL with vector extensions
|
||||
|
||||
### Additional Services
|
||||
- **Dataset I/O**: Local filesystem and Hugging Face integration
|
||||
- **Tool Runtime**: Web search (Brave, Tavily) and RAG capabilities
|
||||
- **Evaluation**: Meta reference evaluation framework
|
||||
|
||||
## Running Llama Stack with OCI
|
||||
|
||||
You can run the OCI distribution via Docker or a local virtual environment.
|
||||
|
||||
### Via venv
|
||||
|
||||
If you've set up your local development environment, you can run the distribution directly from your virtual environment:
|
||||
|
||||
```bash
|
||||
OCI_AUTH_TYPE=$OCI_AUTH_TYPE OCI_REGION=$OCI_REGION OCI_COMPARTMENT_OCID=$OCI_COMPARTMENT_OCID llama stack run --port 8321 oci
|
||||
```
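### Via Docker

The image name below comes from the distribution table at the top of this page; the entrypoint flags may differ between releases, so treat this as a sketch rather than a verified invocation:

```bash
# Sketch: run the OCI distribution container with the documented environment variables
# Add `-v ~/.oci:/root/.oci` if you use config_file authentication inside the container
docker run -it \
  -p 8321:8321 \
  -e OCI_AUTH_TYPE=$OCI_AUTH_TYPE \
  -e OCI_REGION=$OCI_REGION \
  -e OCI_COMPARTMENT_OCID=$OCI_COMPARTMENT_OCID \
  llamastack/distribution-oci \
  --port 8321
```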
|
||||
|
||||
### Configuration Examples
|
||||
|
||||
#### Using Instance Principal (Recommended for Production)
|
||||
```bash
|
||||
export OCI_AUTH_TYPE=instance_principal
|
||||
export OCI_REGION=us-chicago-1
|
||||
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..<your-compartment-id>
|
||||
```
|
||||
|
||||
#### Using API Key Authentication (Development)
|
||||
```bash
|
||||
export OCI_AUTH_TYPE=config_file
|
||||
export OCI_CONFIG_FILE_PATH=~/.oci/config
|
||||
export OCI_CLI_PROFILE=DEFAULT
|
||||
export OCI_REGION=us-chicago-1
|
||||
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..your-compartment-id
|
||||
```
|
||||
|
||||
## Regional Endpoints
|
||||
|
||||
OCI Generative AI is available in multiple regions. The service automatically routes to the appropriate regional endpoint based on your configuration. For a full list of regional model availability, visit:
|
||||
|
||||
https://docs.oracle.com/en-us/iaas/Content/generative-ai/overview.htm#regions
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Authentication Errors**: Verify your OCI credentials and IAM policies (a quick check is sketched after this list)
|
||||
2. **Model Not Found**: Ensure the model OCID is correct and the model is available in your region
|
||||
3. **Permission Denied**: Check compartment permissions and Generative AI service access
|
||||
4. **Region Unavailable**: Verify the specified region supports Generative AI services
|
||||
|
||||
### Getting Help
|
||||
|
||||
For additional support:
|
||||
- [OCI Generative AI Documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)
|
||||
- [Llama Stack Issues](https://github.com/meta-llama/llama-stack/issues)
|
||||
41 docs/docs/providers/inference/remote_oci.mdx (new file)
|
|
@ -0,0 +1,41 @@
|
|||
---
|
||||
description: |
|
||||
Oracle Cloud Infrastructure (OCI) Generative AI inference provider for accessing OCI's Generative AI Platform-as-a-Service models.
|
||||
Provider documentation
|
||||
https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm
|
||||
sidebar_label: Remote - Oci
|
||||
title: remote::oci
|
||||
---
|
||||
|
||||
# remote::oci
|
||||
|
||||
## Description
|
||||
|
||||
|
||||
Oracle Cloud Infrastructure (OCI) Generative AI inference provider for accessing OCI's Generative AI Platform-as-a-Service models.
|
||||
Provider documentation
|
||||
https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm
|
||||
|
||||
|
||||
## Configuration
|
||||
|
||||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
|
||||
| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
|
||||
| `oci_auth_type` | `<class 'str'>` | No | instance_principal | OCI authentication type (must be one of: instance_principal, config_file) |
|
||||
| `oci_region` | `<class 'str'>` | No | us-ashburn-1 | OCI region (e.g., us-ashburn-1) |
|
||||
| `oci_compartment_id` | `<class 'str'>` | No | | OCI compartment ID for the Generative AI service |
|
||||
| `oci_config_file_path` | `<class 'str'>` | No | ~/.oci/config | OCI config file path (required if oci_auth_type is config_file) |
|
||||
| `oci_config_profile` | `<class 'str'>` | No | DEFAULT | OCI config profile (required if oci_auth_type is config_file) |
|
||||
|
||||
## Sample Configuration
|
||||
|
||||
```yaml
|
||||
oci_auth_type: ${env.OCI_AUTH_TYPE:=instance_principal}
|
||||
oci_config_file_path: ${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}
|
||||
oci_config_profile: ${env.OCI_CLI_PROFILE:=DEFAULT}
|
||||
oci_region: ${env.OCI_REGION:=us-ashburn-1}
|
||||
oci_compartment_id: ${env.OCI_COMPARTMENT_OCID:=}
|
||||
```
|
||||
1094 docs/static/deprecated-llama-stack-spec.yaml (vendored): file diff suppressed because it is too large
214 docs/static/experimental-llama-stack-spec.yaml (vendored)
|
|
@ -162,7 +162,7 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/RegisterDatasetRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
deprecated: true
|
||||
/v1beta/datasets/{dataset_id}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -219,7 +219,7 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
deprecated: true
|
||||
/v1alpha/eval/benchmarks:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -270,7 +270,7 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/RegisterBenchmarkRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
deprecated: true
|
||||
/v1alpha/eval/benchmarks/{benchmark_id}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -327,7 +327,7 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
deprecated: true
|
||||
/v1alpha/eval/benchmarks/{benchmark_id}/evaluations:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -936,68 +936,6 @@ components:
|
|||
- data
|
||||
title: ListDatasetsResponse
|
||||
description: Response from listing datasets.
|
||||
DataSource:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/URIDataSource'
|
||||
- $ref: '#/components/schemas/RowsDataSource'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
uri: '#/components/schemas/URIDataSource'
|
||||
rows: '#/components/schemas/RowsDataSource'
|
||||
RegisterDatasetRequest:
|
||||
type: object
|
||||
properties:
|
||||
purpose:
|
||||
type: string
|
||||
enum:
|
||||
- post-training/messages
|
||||
- eval/question-answer
|
||||
- eval/messages-answer
|
||||
description: >-
|
||||
The purpose of the dataset. One of: - "post-training/messages": The dataset
|
||||
contains a messages column with list of messages for post-training. {
|
||||
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
|
||||
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
|
||||
contains a question column and an answer column for evaluation. { "question":
|
||||
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
|
||||
The dataset contains a messages column with list of messages and an answer
|
||||
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
|
||||
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
|
||||
Doe. How can I help you today?"}, {"role": "user", "content": "What's
|
||||
my name?"}, ], "answer": "John Doe" }
|
||||
source:
|
||||
$ref: '#/components/schemas/DataSource'
|
||||
description: >-
|
||||
The data source of the dataset. Ensure that the data source schema is
|
||||
compatible with the purpose of the dataset. Examples: - { "type": "uri",
|
||||
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
|
||||
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
|
||||
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
|
||||
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
|
||||
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
|
||||
} ] }
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
The metadata for the dataset. - E.g. {"description": "My dataset"}.
|
||||
dataset_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the dataset. If not provided, an ID will be generated.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- purpose
|
||||
- source
|
||||
title: RegisterDatasetRequest
|
||||
Benchmark:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -1065,47 +1003,6 @@ components:
|
|||
required:
|
||||
- data
|
||||
title: ListBenchmarksResponse
|
||||
RegisterBenchmarkRequest:
|
||||
type: object
|
||||
properties:
|
||||
benchmark_id:
|
||||
type: string
|
||||
description: The ID of the benchmark to register.
|
||||
dataset_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the dataset to use for the benchmark.
|
||||
scoring_functions:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: >-
|
||||
The scoring functions to use for the benchmark.
|
||||
provider_benchmark_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider benchmark to use for the benchmark.
|
||||
provider_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider to use for the benchmark.
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: The metadata to use for the benchmark.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- benchmark_id
|
||||
- dataset_id
|
||||
- scoring_functions
|
||||
title: RegisterBenchmarkRequest
|
||||
AggregationFunctionType:
|
||||
type: string
|
||||
enum:
|
||||
|
|
@ -2254,6 +2151,109 @@ components:
|
|||
- hyperparam_search_config
|
||||
- logger_config
|
||||
title: SupervisedFineTuneRequest
|
||||
DataSource:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/URIDataSource'
|
||||
- $ref: '#/components/schemas/RowsDataSource'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
uri: '#/components/schemas/URIDataSource'
|
||||
rows: '#/components/schemas/RowsDataSource'
|
||||
RegisterDatasetRequest:
|
||||
type: object
|
||||
properties:
|
||||
purpose:
|
||||
type: string
|
||||
enum:
|
||||
- post-training/messages
|
||||
- eval/question-answer
|
||||
- eval/messages-answer
|
||||
description: >-
|
||||
The purpose of the dataset. One of: - "post-training/messages": The dataset
|
||||
contains a messages column with list of messages for post-training. {
|
||||
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
|
||||
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
|
||||
contains a question column and an answer column for evaluation. { "question":
|
||||
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
|
||||
The dataset contains a messages column with list of messages and an answer
|
||||
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
|
||||
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
|
||||
Doe. How can I help you today?"}, {"role": "user", "content": "What's
|
||||
my name?"}, ], "answer": "John Doe" }
|
||||
source:
|
||||
$ref: '#/components/schemas/DataSource'
|
||||
description: >-
|
||||
The data source of the dataset. Ensure that the data source schema is
|
||||
compatible with the purpose of the dataset. Examples: - { "type": "uri",
|
||||
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
|
||||
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
|
||||
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
|
||||
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
|
||||
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
|
||||
} ] }
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
The metadata for the dataset. - E.g. {"description": "My dataset"}.
|
||||
dataset_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the dataset. If not provided, an ID will be generated.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- purpose
|
||||
- source
|
||||
title: RegisterDatasetRequest
|
||||
RegisterBenchmarkRequest:
|
||||
type: object
|
||||
properties:
|
||||
benchmark_id:
|
||||
type: string
|
||||
description: The ID of the benchmark to register.
|
||||
dataset_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the dataset to use for the benchmark.
|
||||
scoring_functions:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: >-
|
||||
The scoring functions to use for the benchmark.
|
||||
provider_benchmark_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider benchmark to use for the benchmark.
|
||||
provider_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider to use for the benchmark.
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: The metadata to use for the benchmark.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- benchmark_id
|
||||
- dataset_id
|
||||
- scoring_functions
|
||||
title: RegisterBenchmarkRequest
|
||||
responses:
|
||||
BadRequest400:
|
||||
description: The request was invalid or malformed
|
||||
|
|
|
|||
454 docs/static/llama-stack-spec.yaml (vendored)
|
|
@ -960,7 +960,7 @@ paths:
|
|||
Optional filter to control which routes are returned. Can be an API level
|
||||
('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
|
||||
or 'deprecated' to show deprecated routes across all levels. If not specified,
|
||||
returns only non-deprecated v1 routes.
|
||||
returns all non-deprecated routes.
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
|
|
@ -995,39 +995,6 @@ paths:
|
|||
description: List models using the OpenAI API.
|
||||
parameters: []
|
||||
deprecated: false
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: A Model.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Model'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Models
|
||||
summary: Register model.
|
||||
description: >-
|
||||
Register model.
|
||||
|
||||
Register a model.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RegisterModelRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/models/{model_id}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -1062,36 +1029,6 @@ paths:
|
|||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
delete:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Models
|
||||
summary: Unregister model.
|
||||
description: >-
|
||||
Unregister model.
|
||||
|
||||
Unregister a model.
|
||||
parameters:
|
||||
- name: model_id
|
||||
in: path
|
||||
description: >-
|
||||
The identifier of the model to unregister.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
/v1/moderations:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -1722,32 +1659,6 @@ paths:
|
|||
description: List all scoring functions.
|
||||
parameters: []
|
||||
deprecated: false
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- ScoringFunctions
|
||||
summary: Register a scoring function.
|
||||
description: Register a scoring function.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RegisterScoringFunctionRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/scoring-functions/{scoring_fn_id}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -1779,33 +1690,6 @@ paths:
|
|||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
delete:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- ScoringFunctions
|
||||
summary: Unregister a scoring function.
|
||||
description: Unregister a scoring function.
|
||||
parameters:
|
||||
- name: scoring_fn_id
|
||||
in: path
|
||||
description: >-
|
||||
The ID of the scoring function to unregister.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
/v1/scoring/score:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -1894,36 +1778,6 @@ paths:
|
|||
description: List all shields.
|
||||
parameters: []
|
||||
deprecated: false
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: A Shield.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Shield'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Shields
|
||||
summary: Register a shield.
|
||||
description: Register a shield.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RegisterShieldRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/shields/{identifier}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -1955,33 +1809,6 @@ paths:
|
|||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
delete:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Shields
|
||||
summary: Unregister a shield.
|
||||
description: Unregister a shield.
|
||||
parameters:
|
||||
- name: identifier
|
||||
in: path
|
||||
description: >-
|
||||
The identifier of the shield to unregister.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
/v1/tool-runtime/invoke:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -2077,32 +1904,6 @@ paths:
|
|||
description: List tool groups with optional provider.
|
||||
parameters: []
|
||||
deprecated: false
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- ToolGroups
|
||||
summary: Register a tool group.
|
||||
description: Register a tool group.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RegisterToolGroupRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/toolgroups/{toolgroup_id}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -2134,32 +1935,6 @@ paths:
|
|||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
delete:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- ToolGroups
|
||||
summary: Unregister a tool group.
|
||||
description: Unregister a tool group.
|
||||
parameters:
|
||||
- name: toolgroup_id
|
||||
in: path
|
||||
description: The ID of the tool group to unregister.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
/v1/tools:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -2913,11 +2688,11 @@ paths:
|
|||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
A list of InterleavedContent representing the file contents.
|
||||
A VectorStoreFileContentResponse representing the file contents.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/VectorStoreFileContentsResponse'
|
||||
$ref: '#/components/schemas/VectorStoreFileContentResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
|
|
@ -5564,46 +5339,6 @@ components:
|
|||
required:
|
||||
- data
|
||||
title: OpenAIListModelsResponse
|
||||
ModelType:
|
||||
type: string
|
||||
enum:
|
||||
- llm
|
||||
- embedding
|
||||
- rerank
|
||||
title: ModelType
|
||||
description: >-
|
||||
Enumeration of supported model types in Llama Stack.
|
||||
RegisterModelRequest:
|
||||
type: object
|
||||
properties:
|
||||
model_id:
|
||||
type: string
|
||||
description: The identifier of the model to register.
|
||||
provider_model_id:
|
||||
type: string
|
||||
description: >-
|
||||
The identifier of the model in the provider.
|
||||
provider_id:
|
||||
type: string
|
||||
description: The identifier of the provider.
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: Any additional metadata for this model.
|
||||
model_type:
|
||||
$ref: '#/components/schemas/ModelType'
|
||||
description: The type of model to register.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- model_id
|
||||
title: RegisterModelRequest
|
||||
Model:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -5661,6 +5396,15 @@ components:
|
|||
title: Model
|
||||
description: >-
|
||||
A model resource representing an AI model registered in Llama Stack.
|
||||
ModelType:
|
||||
type: string
|
||||
enum:
|
||||
- llm
|
||||
- embedding
|
||||
- rerank
|
||||
title: ModelType
|
||||
description: >-
|
||||
Enumeration of supported model types in Llama Stack.
|
||||
RunModerationRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -6166,6 +5910,11 @@ components:
|
|||
type: string
|
||||
description: >-
|
||||
(Optional) System message inserted into the model's context
|
||||
max_tool_calls:
|
||||
type: integer
|
||||
description: >-
|
||||
(Optional) Max number of total calls to built-in tools that can be processed
|
||||
in a response
|
||||
input:
|
||||
type: array
|
||||
items:
|
||||
|
|
@ -6524,6 +6273,11 @@ components:
|
|||
(Optional) Additional fields to include in the response.
|
||||
max_infer_iters:
|
||||
type: integer
|
||||
max_tool_calls:
|
||||
type: integer
|
||||
description: >-
|
||||
(Optional) Max number of total calls to built-in tools that can be processed
|
||||
in a response.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- input
|
||||
|
|
@ -6605,6 +6359,11 @@ components:
|
|||
type: string
|
||||
description: >-
|
||||
(Optional) System message inserted into the model's context
|
||||
max_tool_calls:
|
||||
type: integer
|
||||
description: >-
|
||||
(Optional) Max number of total calls to built-in tools that can be processed
|
||||
in a response
|
||||
additionalProperties: false
|
||||
required:
|
||||
- created_at
|
||||
|
|
@ -8399,61 +8158,6 @@ components:
|
|||
required:
|
||||
- data
|
||||
title: ListScoringFunctionsResponse
|
||||
ParamType:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/StringType'
|
||||
- $ref: '#/components/schemas/NumberType'
|
||||
- $ref: '#/components/schemas/BooleanType'
|
||||
- $ref: '#/components/schemas/ArrayType'
|
||||
- $ref: '#/components/schemas/ObjectType'
|
||||
- $ref: '#/components/schemas/JsonType'
|
||||
- $ref: '#/components/schemas/UnionType'
|
||||
- $ref: '#/components/schemas/ChatCompletionInputType'
|
||||
- $ref: '#/components/schemas/CompletionInputType'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
string: '#/components/schemas/StringType'
|
||||
number: '#/components/schemas/NumberType'
|
||||
boolean: '#/components/schemas/BooleanType'
|
||||
array: '#/components/schemas/ArrayType'
|
||||
object: '#/components/schemas/ObjectType'
|
||||
json: '#/components/schemas/JsonType'
|
||||
union: '#/components/schemas/UnionType'
|
||||
chat_completion_input: '#/components/schemas/ChatCompletionInputType'
|
||||
completion_input: '#/components/schemas/CompletionInputType'
|
||||
RegisterScoringFunctionRequest:
|
||||
type: object
|
||||
properties:
|
||||
scoring_fn_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the scoring function to register.
|
||||
description:
|
||||
type: string
|
||||
description: The description of the scoring function.
|
||||
return_type:
|
||||
$ref: '#/components/schemas/ParamType'
|
||||
description: The return type of the scoring function.
|
||||
provider_scoring_fn_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider scoring function to use for the scoring function.
|
||||
provider_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider to use for the scoring function.
|
||||
params:
|
||||
$ref: '#/components/schemas/ScoringFnParams'
|
||||
description: >-
|
||||
The parameters for the scoring function for benchmark eval, these can
|
||||
be overridden for app eval.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- scoring_fn_id
|
||||
- description
|
||||
- return_type
|
||||
title: RegisterScoringFunctionRequest
|
||||
ScoreRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -8629,35 +8333,6 @@ components:
|
|||
required:
|
||||
- data
|
||||
title: ListShieldsResponse
|
||||
RegisterShieldRequest:
|
||||
type: object
|
||||
properties:
|
||||
shield_id:
|
||||
type: string
|
||||
description: >-
|
||||
The identifier of the shield to register.
|
||||
provider_shield_id:
|
||||
type: string
|
||||
description: >-
|
||||
The identifier of the shield in the provider.
|
||||
provider_id:
|
||||
type: string
|
||||
description: The identifier of the provider.
|
||||
params:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: The parameters of the shield.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- shield_id
|
||||
title: RegisterShieldRequest
|
||||
InvokeToolRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -8918,37 +8593,6 @@ components:
|
|||
title: ListToolGroupsResponse
|
||||
description: >-
|
||||
Response containing a list of tool groups.
|
||||
RegisterToolGroupRequest:
|
||||
type: object
|
||||
properties:
|
||||
toolgroup_id:
|
||||
type: string
|
||||
description: The ID of the tool group to register.
|
||||
provider_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider to use for the tool group.
|
||||
mcp_endpoint:
|
||||
$ref: '#/components/schemas/URL'
|
||||
description: >-
|
||||
The MCP endpoint to use for the tool group.
|
||||
args:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
A dictionary of arguments to pass to the tool group.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- toolgroup_id
|
||||
- provider_id
|
||||
title: RegisterToolGroupRequest
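The tool group registration route is also being marked deprecated in this change; for reference, a hedged sketch of a RegisterToolGroupRequest payload is shown below. The MCP endpoint value, the provider id, and the assumption that the URL object carries a single uri field are placeholders rather than confirmed details.

import requests

# Hypothetical payload for POST /v1/toolgroups (now a deprecated route);
# toolgroup_id and provider_id are required by the schema above.
toolgroup = {
    "toolgroup_id": "mcp::example-tools",                   # placeholder id
    "provider_id": "model-context-protocol",                # placeholder provider id
    "mcp_endpoint": {"uri": "http://localhost:8000/sse"},   # assumes URL is {"uri": ...}
    "args": {"timeout": 30},
}
resp = requests.post("http://localhost:8321/v1/toolgroups", json=toolgroup, timeout=30)
resp.raise_for_status()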
|
||||
Chunk:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -9749,41 +9393,35 @@ components:
|
|||
title: VectorStoreContent
|
||||
description: >-
|
||||
Content item from a vector store file or search result.
|
||||
VectorStoreFileContentsResponse:
|
||||
VectorStoreFileContentResponse:
|
||||
type: object
|
||||
properties:
|
||||
file_id:
|
||||
object:
|
||||
type: string
|
||||
description: Unique identifier for the file
|
||||
filename:
|
||||
type: string
|
||||
description: Name of the file
|
||||
attributes:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
const: vector_store.file_content.page
|
||||
default: vector_store.file_content.page
|
||||
description: >-
|
||||
Key-value attributes associated with the file
|
||||
content:
|
||||
The object type, which is always `vector_store.file_content.page`
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/VectorStoreContent'
|
||||
description: List of content items from the file
|
||||
description: Parsed content of the file
|
||||
has_more:
|
||||
type: boolean
|
||||
description: >-
|
||||
Indicates if there are more content pages to fetch
|
||||
next_page:
|
||||
type: string
|
||||
description: The token for the next page, if any
|
||||
additionalProperties: false
|
||||
required:
|
||||
- file_id
|
||||
- filename
|
||||
- attributes
|
||||
- content
|
||||
title: VectorStoreFileContentsResponse
|
||||
- object
|
||||
- data
|
||||
- has_more
|
||||
title: VectorStoreFileContentResponse
|
||||
description: >-
|
||||
Response from retrieving the contents of a vector store file.
|
||||
Represents the parsed content of a vector store file.
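To make the reshaped response concrete, the sketch below pages through file contents using the new object, data, has_more, and next_page fields; the request path and the page query parameter name are assumptions, and only the field names are taken from the schema above.

import requests

# Placeholder vector store and file ids; the URL layout and the "page"
# query parameter are assumptions, only the response fields come from
# VectorStoreFileContentResponse above.
url = "http://localhost:8321/v1/vector_stores/vs_123/files/file_456/content"
page = None
while True:
    body = requests.get(url, params={"page": page} if page else {}, timeout=30).json()
    assert body["object"] == "vector_store.file_content.page"
    for item in body["data"]:      # VectorStoreContent items
        print(item)
    if not body["has_more"]:
        break
    page = body.get("next_page")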
|
||||
OpenaiSearchVectorStoreRequest:
|
||||
type: object
|
||||
properties:
|
||||
docs/static/stainless-llama-stack-spec.yaml (vendored, 668 changes)
@ -963,7 +963,7 @@ paths:
|
|||
Optional filter to control which routes are returned. Can be an API level
|
||||
('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
|
||||
or 'deprecated' to show deprecated routes across all levels. If not specified,
|
||||
returns only non-deprecated v1 routes.
|
||||
returns all non-deprecated routes.
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
|
|
@ -998,39 +998,6 @@ paths:
|
|||
description: List models using the OpenAI API.
|
||||
parameters: []
|
||||
deprecated: false
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: A Model.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Model'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Models
|
||||
summary: Register model.
|
||||
description: >-
|
||||
Register model.
|
||||
|
||||
Register a model.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RegisterModelRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/models/{model_id}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -1065,36 +1032,6 @@ paths:
|
|||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
delete:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Models
|
||||
summary: Unregister model.
|
||||
description: >-
|
||||
Unregister model.
|
||||
|
||||
Unregister a model.
|
||||
parameters:
|
||||
- name: model_id
|
||||
in: path
|
||||
description: >-
|
||||
The identifier of the model to unregister.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
/v1/moderations:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -1725,32 +1662,6 @@ paths:
|
|||
description: List all scoring functions.
|
||||
parameters: []
|
||||
deprecated: false
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- ScoringFunctions
|
||||
summary: Register a scoring function.
|
||||
description: Register a scoring function.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RegisterScoringFunctionRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/scoring-functions/{scoring_fn_id}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -1782,33 +1693,6 @@ paths:
|
|||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
delete:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- ScoringFunctions
|
||||
summary: Unregister a scoring function.
|
||||
description: Unregister a scoring function.
|
||||
parameters:
|
||||
- name: scoring_fn_id
|
||||
in: path
|
||||
description: >-
|
||||
The ID of the scoring function to unregister.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
/v1/scoring/score:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -1897,36 +1781,6 @@ paths:
|
|||
description: List all shields.
|
||||
parameters: []
|
||||
deprecated: false
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: A Shield.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Shield'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Shields
|
||||
summary: Register a shield.
|
||||
description: Register a shield.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RegisterShieldRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/shields/{identifier}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -1958,33 +1812,6 @@ paths:
|
|||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
delete:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Shields
|
||||
summary: Unregister a shield.
|
||||
description: Unregister a shield.
|
||||
parameters:
|
||||
- name: identifier
|
||||
in: path
|
||||
description: >-
|
||||
The identifier of the shield to unregister.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
/v1/tool-runtime/invoke:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -2080,32 +1907,6 @@ paths:
|
|||
description: List tool groups with optional provider.
|
||||
parameters: []
|
||||
deprecated: false
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- ToolGroups
|
||||
summary: Register a tool group.
|
||||
description: Register a tool group.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RegisterToolGroupRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/toolgroups/{toolgroup_id}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -2137,32 +1938,6 @@ paths:
|
|||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
delete:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- ToolGroups
|
||||
summary: Unregister a tool group.
|
||||
description: Unregister a tool group.
|
||||
parameters:
|
||||
- name: toolgroup_id
|
||||
in: path
|
||||
description: The ID of the tool group to unregister.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
/v1/tools:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -2916,11 +2691,11 @@ paths:
|
|||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
A list of InterleavedContent representing the file contents.
|
||||
A VectorStoreFileContentResponse representing the file contents.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/VectorStoreFileContentsResponse'
|
||||
$ref: '#/components/schemas/VectorStoreFileContentResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
|
|
@ -3171,7 +2946,7 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/RegisterDatasetRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
deprecated: true
|
||||
/v1beta/datasets/{dataset_id}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -3228,7 +3003,7 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
deprecated: true
|
||||
/v1alpha/eval/benchmarks:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -3279,7 +3054,7 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/RegisterBenchmarkRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
deprecated: true
|
||||
/v1alpha/eval/benchmarks/{benchmark_id}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -3336,7 +3111,7 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
deprecated: true
|
||||
/v1alpha/eval/benchmarks/{benchmark_id}/evaluations:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -6280,46 +6055,6 @@ components:
|
|||
required:
|
||||
- data
|
||||
title: OpenAIListModelsResponse
|
||||
ModelType:
|
||||
type: string
|
||||
enum:
|
||||
- llm
|
||||
- embedding
|
||||
- rerank
|
||||
title: ModelType
|
||||
description: >-
|
||||
Enumeration of supported model types in Llama Stack.
|
||||
RegisterModelRequest:
|
||||
type: object
|
||||
properties:
|
||||
model_id:
|
||||
type: string
|
||||
description: The identifier of the model to register.
|
||||
provider_model_id:
|
||||
type: string
|
||||
description: >-
|
||||
The identifier of the model in the provider.
|
||||
provider_id:
|
||||
type: string
|
||||
description: The identifier of the provider.
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: Any additional metadata for this model.
|
||||
model_type:
|
||||
$ref: '#/components/schemas/ModelType'
|
||||
description: The type of model to register.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- model_id
|
||||
title: RegisterModelRequest
|
||||
Model:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -6377,6 +6112,15 @@ components:
|
|||
title: Model
|
||||
description: >-
|
||||
A model resource representing an AI model registered in Llama Stack.
|
||||
ModelType:
|
||||
type: string
|
||||
enum:
|
||||
- llm
|
||||
- embedding
|
||||
- rerank
|
||||
title: ModelType
|
||||
description: >-
|
||||
Enumeration of supported model types in Llama Stack.
|
||||
RunModerationRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -6882,6 +6626,11 @@ components:
|
|||
type: string
|
||||
description: >-
|
||||
(Optional) System message inserted into the model's context
|
||||
max_tool_calls:
|
||||
type: integer
|
||||
description: >-
|
||||
(Optional) Max number of total calls to built-in tools that can be processed
|
||||
in a response
|
||||
input:
|
||||
type: array
|
||||
items:
|
||||
|
|
@ -7240,6 +6989,11 @@ components:
|
|||
(Optional) Additional fields to include in the response.
|
||||
max_infer_iters:
|
||||
type: integer
|
||||
max_tool_calls:
|
||||
type: integer
|
||||
description: >-
|
||||
(Optional) Max number of total calls to built-in tools that can be processed
|
||||
in a response.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- input
|
||||
|
|
@ -7321,6 +7075,11 @@ components:
|
|||
type: string
|
||||
description: >-
|
||||
(Optional) System message inserted into the model's context
|
||||
max_tool_calls:
|
||||
type: integer
|
||||
description: >-
|
||||
(Optional) Max number of total calls to built-in tools that can be processed
|
||||
in a response
|
||||
additionalProperties: false
|
||||
required:
|
||||
- created_at
|
||||
|
|
@ -9115,61 +8874,6 @@ components:
|
|||
required:
|
||||
- data
|
||||
title: ListScoringFunctionsResponse
|
||||
ParamType:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/StringType'
|
||||
- $ref: '#/components/schemas/NumberType'
|
||||
- $ref: '#/components/schemas/BooleanType'
|
||||
- $ref: '#/components/schemas/ArrayType'
|
||||
- $ref: '#/components/schemas/ObjectType'
|
||||
- $ref: '#/components/schemas/JsonType'
|
||||
- $ref: '#/components/schemas/UnionType'
|
||||
- $ref: '#/components/schemas/ChatCompletionInputType'
|
||||
- $ref: '#/components/schemas/CompletionInputType'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
string: '#/components/schemas/StringType'
|
||||
number: '#/components/schemas/NumberType'
|
||||
boolean: '#/components/schemas/BooleanType'
|
||||
array: '#/components/schemas/ArrayType'
|
||||
object: '#/components/schemas/ObjectType'
|
||||
json: '#/components/schemas/JsonType'
|
||||
union: '#/components/schemas/UnionType'
|
||||
chat_completion_input: '#/components/schemas/ChatCompletionInputType'
|
||||
completion_input: '#/components/schemas/CompletionInputType'
|
||||
RegisterScoringFunctionRequest:
|
||||
type: object
|
||||
properties:
|
||||
scoring_fn_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the scoring function to register.
|
||||
description:
|
||||
type: string
|
||||
description: The description of the scoring function.
|
||||
return_type:
|
||||
$ref: '#/components/schemas/ParamType'
|
||||
description: The return type of the scoring function.
|
||||
provider_scoring_fn_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider scoring function to use for the scoring function.
|
||||
provider_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider to use for the scoring function.
|
||||
params:
|
||||
$ref: '#/components/schemas/ScoringFnParams'
|
||||
description: >-
|
||||
The parameters for the scoring function for benchmark eval, these can
|
||||
be overridden for app eval.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- scoring_fn_id
|
||||
- description
|
||||
- return_type
|
||||
title: RegisterScoringFunctionRequest
|
||||
ScoreRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -9345,35 +9049,6 @@ components:
|
|||
required:
|
||||
- data
|
||||
title: ListShieldsResponse
|
||||
RegisterShieldRequest:
|
||||
type: object
|
||||
properties:
|
||||
shield_id:
|
||||
type: string
|
||||
description: >-
|
||||
The identifier of the shield to register.
|
||||
provider_shield_id:
|
||||
type: string
|
||||
description: >-
|
||||
The identifier of the shield in the provider.
|
||||
provider_id:
|
||||
type: string
|
||||
description: The identifier of the provider.
|
||||
params:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: The parameters of the shield.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- shield_id
|
||||
title: RegisterShieldRequest
|
||||
InvokeToolRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -9634,37 +9309,6 @@ components:
|
|||
title: ListToolGroupsResponse
|
||||
description: >-
|
||||
Response containing a list of tool groups.
|
||||
RegisterToolGroupRequest:
|
||||
type: object
|
||||
properties:
|
||||
toolgroup_id:
|
||||
type: string
|
||||
description: The ID of the tool group to register.
|
||||
provider_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider to use for the tool group.
|
||||
mcp_endpoint:
|
||||
$ref: '#/components/schemas/URL'
|
||||
description: >-
|
||||
The MCP endpoint to use for the tool group.
|
||||
args:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
A dictionary of arguments to pass to the tool group.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- toolgroup_id
|
||||
- provider_id
|
||||
title: RegisterToolGroupRequest
|
||||
Chunk:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -10465,41 +10109,35 @@ components:
|
|||
title: VectorStoreContent
|
||||
description: >-
|
||||
Content item from a vector store file or search result.
|
||||
VectorStoreFileContentsResponse:
|
||||
VectorStoreFileContentResponse:
|
||||
type: object
|
||||
properties:
|
||||
file_id:
|
||||
object:
|
||||
type: string
|
||||
description: Unique identifier for the file
|
||||
filename:
|
||||
type: string
|
||||
description: Name of the file
|
||||
attributes:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
const: vector_store.file_content.page
|
||||
default: vector_store.file_content.page
|
||||
description: >-
|
||||
Key-value attributes associated with the file
|
||||
content:
|
||||
The object type, which is always `vector_store.file_content.page`
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/VectorStoreContent'
|
||||
description: List of content items from the file
|
||||
description: Parsed content of the file
|
||||
has_more:
|
||||
type: boolean
|
||||
description: >-
|
||||
Indicates if there are more content pages to fetch
|
||||
next_page:
|
||||
type: string
|
||||
description: The token for the next page, if any
|
||||
additionalProperties: false
|
||||
required:
|
||||
- file_id
|
||||
- filename
|
||||
- attributes
|
||||
- content
|
||||
title: VectorStoreFileContentsResponse
|
||||
- object
|
||||
- data
|
||||
- has_more
|
||||
title: VectorStoreFileContentResponse
|
||||
description: >-
|
||||
Response from retrieving the contents of a vector store file.
|
||||
Represents the parsed content of a vector store file.
|
||||
OpenaiSearchVectorStoreRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -10816,68 +10454,6 @@ components:
|
|||
- data
|
||||
title: ListDatasetsResponse
|
||||
description: Response from listing datasets.
|
||||
DataSource:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/URIDataSource'
|
||||
- $ref: '#/components/schemas/RowsDataSource'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
uri: '#/components/schemas/URIDataSource'
|
||||
rows: '#/components/schemas/RowsDataSource'
|
||||
RegisterDatasetRequest:
|
||||
type: object
|
||||
properties:
|
||||
purpose:
|
||||
type: string
|
||||
enum:
|
||||
- post-training/messages
|
||||
- eval/question-answer
|
||||
- eval/messages-answer
|
||||
description: >-
|
||||
The purpose of the dataset. One of: - "post-training/messages": The dataset
|
||||
contains a messages column with list of messages for post-training. {
|
||||
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
|
||||
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
|
||||
contains a question column and an answer column for evaluation. { "question":
|
||||
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
|
||||
The dataset contains a messages column with list of messages and an answer
|
||||
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
|
||||
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
|
||||
Doe. How can I help you today?"}, {"role": "user", "content": "What's
|
||||
my name?"}, ], "answer": "John Doe" }
|
||||
source:
|
||||
$ref: '#/components/schemas/DataSource'
|
||||
description: >-
|
||||
The data source of the dataset. Ensure that the data source schema is
|
||||
compatible with the purpose of the dataset. Examples: - { "type": "uri",
|
||||
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
|
||||
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
|
||||
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
|
||||
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
|
||||
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
|
||||
} ] }
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
The metadata for the dataset. - E.g. {"description": "My dataset"}.
|
||||
dataset_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the dataset. If not provided, an ID will be generated.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- purpose
|
||||
- source
|
||||
title: RegisterDatasetRequest
|
||||
Benchmark:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -10945,47 +10521,6 @@ components:
|
|||
required:
|
||||
- data
|
||||
title: ListBenchmarksResponse
|
||||
RegisterBenchmarkRequest:
|
||||
type: object
|
||||
properties:
|
||||
benchmark_id:
|
||||
type: string
|
||||
description: The ID of the benchmark to register.
|
||||
dataset_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the dataset to use for the benchmark.
|
||||
scoring_functions:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: >-
|
||||
The scoring functions to use for the benchmark.
|
||||
provider_benchmark_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider benchmark to use for the benchmark.
|
||||
provider_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider to use for the benchmark.
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: The metadata to use for the benchmark.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- benchmark_id
|
||||
- dataset_id
|
||||
- scoring_functions
|
||||
title: RegisterBenchmarkRequest
|
||||
BenchmarkConfig:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -11847,6 +11382,109 @@ components:
|
|||
- hyperparam_search_config
|
||||
- logger_config
|
||||
title: SupervisedFineTuneRequest
|
||||
DataSource:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/URIDataSource'
|
||||
- $ref: '#/components/schemas/RowsDataSource'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
uri: '#/components/schemas/URIDataSource'
|
||||
rows: '#/components/schemas/RowsDataSource'
|
||||
RegisterDatasetRequest:
|
||||
type: object
|
||||
properties:
|
||||
purpose:
|
||||
type: string
|
||||
enum:
|
||||
- post-training/messages
|
||||
- eval/question-answer
|
||||
- eval/messages-answer
|
||||
description: >-
|
||||
The purpose of the dataset. One of: - "post-training/messages": The dataset
|
||||
contains a messages column with list of messages for post-training. {
|
||||
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
|
||||
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
|
||||
contains a question column and an answer column for evaluation. { "question":
|
||||
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
|
||||
The dataset contains a messages column with list of messages and an answer
|
||||
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
|
||||
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
|
||||
Doe. How can I help you today?"}, {"role": "user", "content": "What's
|
||||
my name?"}, ], "answer": "John Doe" }
|
||||
source:
|
||||
$ref: '#/components/schemas/DataSource'
|
||||
description: >-
|
||||
The data source of the dataset. Ensure that the data source schema is
|
||||
compatible with the purpose of the dataset. Examples: - { "type": "uri",
|
||||
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
|
||||
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
|
||||
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
|
||||
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
|
||||
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
|
||||
} ] }
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
The metadata for the dataset. - E.g. {"description": "My dataset"}.
|
||||
dataset_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the dataset. If not provided, an ID will be generated.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- purpose
|
||||
- source
|
||||
title: RegisterDatasetRequest
|
||||
RegisterBenchmarkRequest:
|
||||
type: object
|
||||
properties:
|
||||
benchmark_id:
|
||||
type: string
|
||||
description: The ID of the benchmark to register.
|
||||
dataset_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the dataset to use for the benchmark.
|
||||
scoring_functions:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: >-
|
||||
The scoring functions to use for the benchmark.
|
||||
provider_benchmark_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider benchmark to use for the benchmark.
|
||||
provider_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider to use for the benchmark.
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: The metadata to use for the benchmark.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- benchmark_id
|
||||
- dataset_id
|
||||
- scoring_functions
|
||||
title: RegisterBenchmarkRequest
|
||||
responses:
|
||||
BadRequest400:
|
||||
description: The request was invalid or malformed
|
||||
|
|
@ -298,6 +298,7 @@ exclude = [
|
|||
"^src/llama_stack/providers/remote/agents/sample/",
|
||||
"^src/llama_stack/providers/remote/datasetio/huggingface/",
|
||||
"^src/llama_stack/providers/remote/datasetio/nvidia/",
|
||||
"^src/llama_stack/providers/remote/inference/oci/",
|
||||
"^src/llama_stack/providers/remote/inference/bedrock/",
|
||||
"^src/llama_stack/providers/remote/inference/nvidia/",
|
||||
"^src/llama_stack/providers/remote/inference/passthrough/",
|
||||
|
|
@ -87,6 +87,7 @@ class Agents(Protocol):
|
|||
"List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
|
||||
),
|
||||
] = None,
|
||||
max_tool_calls: int | None = None,
|
||||
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Create a model response.
|
||||
|
||||
|
|
@ -97,6 +98,7 @@ class Agents(Protocol):
|
|||
:param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
|
||||
:param include: (Optional) Additional fields to include in the response.
|
||||
:param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
|
||||
:param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response.
|
||||
:returns: An OpenAIResponseObject.
|
||||
"""
|
||||
...
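A minimal sketch of the new max_tool_calls parameter in use follows; it assumes an OpenAI-compatible /v1/responses endpoint on a locally running server and a placeholder model id, so treat it as illustrative rather than a confirmed call signature.

import requests

# Hedged sketch: caps the number of built-in tool calls in a single response.
body = {
    "model": "llama3.2-3b-instruct",   # placeholder model id
    "input": "Look up today's schedule and summarize it.",
    "max_tool_calls": 2,               # new optional cap introduced by this change
}
resp = requests.post("http://localhost:8321/v1/responses", json=body, timeout=60)
print(resp.json().get("max_tool_calls"))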
|
||||
|
|
|
|||
|
|
@ -594,6 +594,7 @@ class OpenAIResponseObject(BaseModel):
|
|||
:param truncation: (Optional) Truncation strategy applied to the response
|
||||
:param usage: (Optional) Token usage information for the response
|
||||
:param instructions: (Optional) System message inserted into the model's context
|
||||
:param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response
|
||||
"""
|
||||
|
||||
created_at: int
|
||||
|
|
@ -615,6 +616,7 @@ class OpenAIResponseObject(BaseModel):
|
|||
truncation: str | None = None
|
||||
usage: OpenAIResponseUsage | None = None
|
||||
instructions: str | None = None
|
||||
max_tool_calls: int | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class Benchmarks(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA, deprecated=True)
|
||||
async def register_benchmark(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
|
|
@ -95,7 +95,7 @@ class Benchmarks(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA, deprecated=True)
|
||||
async def unregister_benchmark(self, benchmark_id: str) -> None:
|
||||
"""Unregister a benchmark.
|
||||
|
||||
|
|
|
|||
|
|
@ -146,7 +146,7 @@ class ListDatasetsResponse(BaseModel):
|
|||
|
||||
|
||||
class Datasets(Protocol):
|
||||
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA)
|
||||
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA, deprecated=True)
|
||||
async def register_dataset(
|
||||
self,
|
||||
purpose: DatasetPurpose,
|
||||
|
|
@ -235,7 +235,7 @@ class Datasets(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA)
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA, deprecated=True)
|
||||
async def unregister_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
|
|
|||
|
|
@ -1,43 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
)
|
||||
|
||||
|
||||
class LogEvent:
|
||||
def __init__(
|
||||
self,
|
||||
content: str = "",
|
||||
end: str = "\n",
|
||||
color="white",
|
||||
):
|
||||
self.content = content
|
||||
self.color = color
|
||||
self.end = "\n" if end is None else end
|
||||
|
||||
def print(self, flush=True):
|
||||
cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)
|
||||
|
||||
|
||||
class EventLogger:
|
||||
async def log(self, event_generator):
|
||||
async for chunk in event_generator:
|
||||
if isinstance(chunk, ChatCompletionResponseStreamChunk):
|
||||
event = chunk.event
|
||||
if event.event_type == ChatCompletionResponseEventType.start:
|
||||
yield LogEvent("Assistant> ", color="cyan", end="")
|
||||
elif event.event_type == ChatCompletionResponseEventType.progress:
|
||||
yield LogEvent(event.delta, color="yellow", end="")
|
||||
elif event.event_type == ChatCompletionResponseEventType.complete:
|
||||
yield LogEvent("")
|
||||
else:
|
||||
yield LogEvent("Assistant> ", color="cyan", end="")
|
||||
yield LogEvent(chunk.completion_message.content, color="yellow")
|
||||
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from enum import Enum
|
||||
from enum import Enum, StrEnum
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
|
|
@ -15,28 +15,18 @@ from typing import (
|
|||
)
|
||||
|
||||
from fastapi import Body
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
|
||||
from llama_stack.apis.common.responses import MetricResponseMixin, Order
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.common.responses import (
|
||||
Order,
|
||||
)
|
||||
from llama_stack.apis.common.tracing import telemetry_traceable
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
register_schema(ToolCall)
|
||||
register_schema(ToolDefinition)
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GreedySamplingStrategy(BaseModel):
|
||||
|
|
@ -201,58 +191,6 @@ class ToolResponseMessage(BaseModel):
|
|||
content: InterleavedContent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionMessage(BaseModel):
|
||||
"""A message containing the model's (assistant) response in a chat conversation.
|
||||
|
||||
:param role: Must be "assistant" to identify this as the model's response
|
||||
:param content: The content of the model's response
|
||||
:param stop_reason: Reason why the model stopped generating. Options are:
|
||||
- `StopReason.end_of_turn`: The model finished generating the entire response.
|
||||
- `StopReason.end_of_message`: The model finished generating but generated a partial response -- usually, a tool call. The user may call the tool and continue the conversation with the tool's response.
|
||||
- `StopReason.out_of_tokens`: The model ran out of token budget.
|
||||
:param tool_calls: List of tool calls. Each tool call is a ToolCall object.
|
||||
"""
|
||||
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: InterleavedContent
|
||||
stop_reason: StopReason
|
||||
tool_calls: list[ToolCall] | None = Field(default_factory=lambda: [])
|
||||
|
||||
|
||||
Message = Annotated[
|
||||
UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
|
||||
Field(discriminator="role"),
|
||||
]
|
||||
register_schema(Message, name="Message")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolResponse(BaseModel):
|
||||
"""Response from a tool invocation.
|
||||
|
||||
:param call_id: Unique identifier for the tool call this response is for
|
||||
:param tool_name: Name of the tool that was invoked
|
||||
:param content: The response content from the tool
|
||||
:param metadata: (Optional) Additional metadata about the tool response
|
||||
"""
|
||||
|
||||
call_id: str
|
||||
tool_name: BuiltinTool | str
|
||||
content: InterleavedContent
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
def validate_field(cls, v):
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return BuiltinTool(v)
|
||||
except ValueError:
|
||||
return v
|
||||
return v
|
||||
|
||||
|
||||
class ToolChoice(Enum):
|
||||
"""Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model.
|
||||
|
||||
|
|
@ -289,22 +227,6 @@ class ChatCompletionResponseEventType(Enum):
|
|||
progress = "progress"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ChatCompletionResponseEvent(BaseModel):
|
||||
"""An event during chat completion generation.
|
||||
|
||||
:param event_type: Type of the event
|
||||
:param delta: Content generated since last event. This can be one or more tokens, or a tool call.
|
||||
:param logprobs: Optional log probabilities for generated tokens
|
||||
:param stop_reason: Optional reason why generation stopped, if complete
|
||||
"""
|
||||
|
||||
event_type: ChatCompletionResponseEventType
|
||||
delta: ContentDelta
|
||||
logprobs: list[TokenLogProbs] | None = None
|
||||
stop_reason: StopReason | None = None
|
||||
|
||||
|
||||
class ResponseFormatType(StrEnum):
|
||||
"""Types of formats for structured (guided) decoding.
|
||||
|
||||
|
|
@ -357,34 +279,6 @@ class CompletionRequest(BaseModel):
|
|||
logprobs: LogProbConfig | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionResponse(MetricResponseMixin):
|
||||
"""Response from a completion request.
|
||||
|
||||
:param content: The generated completion text
|
||||
:param stop_reason: Reason why generation stopped
|
||||
:param logprobs: Optional log probabilities for generated tokens
|
||||
"""
|
||||
|
||||
content: str
|
||||
stop_reason: StopReason
|
||||
logprobs: list[TokenLogProbs] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionResponseStreamChunk(MetricResponseMixin):
|
||||
"""A chunk of a streamed completion response.
|
||||
|
||||
:param delta: New content generated since last chunk. This can be one or more tokens.
|
||||
:param stop_reason: Optional reason why generation stopped, if complete
|
||||
:param logprobs: Optional log probabilities for generated tokens
|
||||
"""
|
||||
|
||||
delta: str
|
||||
stop_reason: StopReason | None = None
|
||||
logprobs: list[TokenLogProbs] | None = None
|
||||
|
||||
|
||||
class SystemMessageBehavior(Enum):
|
||||
"""Config for how to override the default system prompt.
|
||||
|
||||
|
|
@ -398,70 +292,6 @@ class SystemMessageBehavior(Enum):
|
|||
replace = "replace"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolConfig(BaseModel):
|
||||
"""Configuration for tool use.
|
||||
|
||||
:param tool_choice: (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
|
||||
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
|
||||
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
|
||||
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
|
||||
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
|
||||
:param system_message_behavior: (Optional) Config for how to override the default system prompt.
|
||||
- `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt.
|
||||
- `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string
|
||||
'{{function_definitions}}' to indicate where the function definitions should be inserted.
|
||||
"""
|
||||
|
||||
tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: ToolPromptFormat | None = Field(default=None)
|
||||
system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if isinstance(self.tool_choice, str):
|
||||
try:
|
||||
self.tool_choice = ToolChoice[self.tool_choice]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
# This is an internally used class
@json_schema_type
class ChatCompletionRequest(BaseModel):
    model: str
    messages: list[Message]
    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)

    tools: list[ToolDefinition] | None = Field(default_factory=lambda: [])
    tool_config: ToolConfig | None = Field(default_factory=ToolConfig)

    response_format: ResponseFormat | None = None
    stream: bool | None = False
    logprobs: LogProbConfig | None = None


@json_schema_type
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
    """A chunk of a streamed chat completion response.

    :param event: The event containing the new content
    """

    event: ChatCompletionResponseEvent


@json_schema_type
class ChatCompletionResponse(MetricResponseMixin):
    """Response from a chat completion request.

    :param completion_message: The complete response message
    :param logprobs: Optional log probabilities for generated tokens
    """

    completion_message: CompletionMessage
    logprobs: list[TokenLogProbs] | None = None


@json_schema_type
class EmbeddingsResponse(BaseModel):
    """Response containing generated embeddings.
@@ -76,7 +76,7 @@ class Inspect(Protocol):

        List all available API routes with their methods and implementing providers.

-        :param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes.
+        :param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns all non-deprecated routes.
        :returns: Response containing information about all available routes.
        """
        ...
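The effect of the new default is easiest to see from a client call; a hypothetical sketch (assumes a stack server on localhost:8321 and the inspect routes endpoint mounted at /v1/inspect/routes):

```python
import requests

BASE = "http://localhost:8321"  # assumed local server address

# Default now returns non-deprecated routes across all API levels.
all_routes = requests.get(f"{BASE}/v1/inspect/routes").json()

# Explicit filters still narrow the listing, e.g. only deprecated routes.
deprecated = requests.get(f"{BASE}/v1/inspect/routes", params={"api_filter": "deprecated"}).json()
```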
@@ -136,7 +136,7 @@ class Models(Protocol):
        """
        ...

-    @webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1)
+    @webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
    async def register_model(
        self,
        model_id: str,

@@ -158,7 +158,7 @@ class Models(Protocol):
        """
        ...

-    @webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
+    @webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
    async def unregister_model(
        self,
        model_id: str,
@@ -178,7 +178,7 @@ class ScoringFunctions(Protocol):
        """
        ...

-    @webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1)
+    @webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
    async def register_scoring_function(
        self,
        scoring_fn_id: str,

@@ -199,7 +199,9 @@ class ScoringFunctions(Protocol):
        """
        ...

-    @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
+    @webmethod(
+        route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True
+    )
    async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
        """Unregister a scoring function.
@@ -67,7 +67,7 @@ class Shields(Protocol):
        """
        ...

-    @webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1)
+    @webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
    async def register_shield(
        self,
        shield_id: str,

@@ -85,7 +85,7 @@ class Shields(Protocol):
        """
        ...

-    @webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1)
+    @webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
    async def unregister_shield(self, identifier: str) -> None:
        """Unregister a shield.
@@ -109,7 +109,7 @@ class ListToolDefsResponse(BaseModel):
@runtime_checkable
@telemetry_traceable
class ToolGroups(Protocol):
-    @webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1)
+    @webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
    async def register_tool_group(
        self,
        toolgroup_id: str,

@@ -167,7 +167,7 @@ class ToolGroups(Protocol):
        """
        ...

-    @webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
+    @webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
    async def unregister_toolgroup(
        self,
        toolgroup_id: str,
@@ -396,19 +396,19 @@ class VectorStoreListFilesResponse(BaseModel):


@json_schema_type
-class VectorStoreFileContentsResponse(BaseModel):
-    """Response from retrieving the contents of a vector store file.
+class VectorStoreFileContentResponse(BaseModel):
+    """Represents the parsed content of a vector store file.

-    :param file_id: Unique identifier for the file
-    :param filename: Name of the file
-    :param attributes: Key-value attributes associated with the file
-    :param content: List of content items from the file
+    :param object: The object type, which is always `vector_store.file_content.page`
+    :param data: Parsed content of the file
+    :param has_more: Indicates if there are more content pages to fetch
+    :param next_page: The token for the next page, if any
    """

-    file_id: str
-    filename: str
-    attributes: dict[str, Any]
-    content: list[VectorStoreContent]
+    object: Literal["vector_store.file_content.page"] = "vector_store.file_content.page"
+    data: list[VectorStoreContent]
+    has_more: bool
+    next_page: str | None = None


@json_schema_type
@@ -732,12 +732,12 @@ class VectorIO(Protocol):
        self,
        vector_store_id: str,
        file_id: str,
-    ) -> VectorStoreFileContentsResponse:
+    ) -> VectorStoreFileContentResponse:
        """Retrieves the contents of a vector store file.

        :param vector_store_id: The ID of the vector store containing the file to retrieve.
        :param file_id: The ID of the file to retrieve.
-        :returns: A list of InterleavedContent representing the file contents.
+        :returns: A VectorStoreFileContentResponse representing the file contents.
        """
        ...
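For reference, a sketch of what a page built from the renamed model might look like (field values are hypothetical; only the fields shown in the new class are used):

```python
# Sketch: one page of parsed file content under the new paginated response shape.
page = VectorStoreFileContentResponse(
    data=[VectorStoreContent(type="text", text="First parsed chunk of the file...")],
    has_more=True,          # a further page can be fetched
    next_page="page-2",     # opaque pagination token (hypothetical value)
)
assert page.object == "vector_store.file_content.page"
```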
@@ -15,7 +15,6 @@ from llama_stack.apis.inspect import (
    RouteInfo,
    VersionInfo,
)
-from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis
from llama_stack.core.server.routes import get_all_api_routes

@@ -46,8 +45,8 @@ class DistributionInspectImpl(Inspect):
        # Helper function to determine if a route should be included based on api_filter
        def should_include_route(webmethod) -> bool:
            if api_filter is None:
-                # Default: only non-deprecated v1 APIs
-                return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1
+                # Default: only non-deprecated APIs
+                return not webmethod.deprecated
            elif api_filter == "deprecated":
                # Special filter: show deprecated routes regardless of their actual level
                return bool(webmethod.deprecated)
@@ -6,7 +6,7 @@

from typing import Any

-from llama_stack.apis.inference import Message
+from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield

@@ -52,7 +52,7 @@ class SafetyRouter(Safety):
    async def run_shield(
        self,
        shield_id: str,
-        messages: list[Message],
+        messages: list[OpenAIMessageParam],
        params: dict[str, Any] = None,
    ) -> RunShieldResponse:
        logger.debug(f"SafetyRouter.run_shield: {shield_id}")
@@ -24,7 +24,7 @@ from llama_stack.apis.vector_io import (
    VectorStoreChunkingStrategyStaticConfig,
    VectorStoreDeleteResponse,
    VectorStoreFileBatchObject,
-    VectorStoreFileContentsResponse,
+    VectorStoreFileContentResponse,
    VectorStoreFileDeleteResponse,
    VectorStoreFileObject,
    VectorStoreFilesListInBatchResponse,

@@ -338,7 +338,7 @@ class VectorIORouter(VectorIO):
        self,
        vector_store_id: str,
        file_id: str,
-    ) -> VectorStoreFileContentsResponse:
+    ) -> VectorStoreFileContentResponse:
        logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file_contents(
@@ -15,7 +15,7 @@ from llama_stack.apis.vector_io.vector_io import (
    SearchRankingOptions,
    VectorStoreChunkingStrategy,
    VectorStoreDeleteResponse,
-    VectorStoreFileContentsResponse,
+    VectorStoreFileContentResponse,
    VectorStoreFileDeleteResponse,
    VectorStoreFileObject,
    VectorStoreFileStatus,

@@ -195,7 +195,7 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
        self,
        vector_store_id: str,
        file_id: str,
-    ) -> VectorStoreFileContentsResponse:
+    ) -> VectorStoreFileContentResponse:
        await self.assert_action_allowed("read", "vector_store", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file_contents(
src/llama_stack/distributions/oci/__init__.py (new file, 7 lines)

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .oci import get_distribution_template  # noqa: F401
src/llama_stack/distributions/oci/build.yaml (new file, 35 lines)

@@ -0,0 +1,35 @@
version: 2
distribution_spec:
  description: Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM
    inference with scalable cloud services
  providers:
    inference:
    - provider_type: remote::oci
    vector_io:
    - provider_type: inline::faiss
    - provider_type: remote::chromadb
    - provider_type: remote::pgvector
    safety:
    - provider_type: inline::llama-guard
    agents:
    - provider_type: inline::meta-reference
    eval:
    - provider_type: inline::meta-reference
    datasetio:
    - provider_type: remote::huggingface
    - provider_type: inline::localfs
    scoring:
    - provider_type: inline::basic
    - provider_type: inline::llm-as-judge
    - provider_type: inline::braintrust
    tool_runtime:
    - provider_type: remote::brave-search
    - provider_type: remote::tavily-search
    - provider_type: inline::rag-runtime
    - provider_type: remote::model-context-protocol
    files:
    - provider_type: inline::localfs
image_type: venv
additional_pip_packages:
- aiosqlite
- sqlalchemy[asyncio]
src/llama_stack/distributions/oci/doc_template.md (new file, 140 lines)

@@ -0,0 +1,140 @@
---
orphan: true
---
# OCI Distribution

The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.

{{ providers_table }}

{% if run_config_env_vars %}
### Environment Variables

The following environment variables can be configured:

{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}

{% if default_models %}
### Models

The following models are available by default:

{% for model in default_models %}
- `{{ model.model_id }} {{ model.doc_string }}`
{% endfor %}
{% endif %}

## Prerequisites

### Oracle Cloud Infrastructure Setup

Before using the OCI Generative AI distribution, ensure you have:

1. **Oracle Cloud Infrastructure Account**: Sign up at [Oracle Cloud Infrastructure](https://cloud.oracle.com/)
2. **Generative AI Service Access**: Enable the Generative AI service in your OCI tenancy
3. **Compartment**: Create or identify a compartment where you'll deploy Generative AI models
4. **Authentication**: Configure authentication using either:
   - **Instance Principal** (recommended for cloud-hosted deployments)
   - **API Key** (for on-premises or development environments)

### Authentication Methods

#### Instance Principal Authentication (Recommended)
Instance Principal authentication allows OCI resources to authenticate using the identity of the compute instance they're running on. This is the most secure method for production deployments.

Requirements:
- Instance must be running in an Oracle Cloud Infrastructure compartment
- Instance must have appropriate IAM policies to access Generative AI services

#### API Key Authentication
For development or on-premises deployments, follow [this doc](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/apisigningkey.htm) to learn how to create your API signing key for your config file.

### Required IAM Policies

Ensure your OCI user or instance has the following policy statements:

```
Allow group <group_name> to use generative-ai-inference-endpoints in compartment <compartment_name>
Allow group <group_name> to manage generative-ai-inference-endpoints in compartment <compartment_name>
```

## Supported Services

### Inference: OCI Generative AI
Oracle Cloud Infrastructure Generative AI provides access to high-performance AI models through OCI's Platform-as-a-Service offering. The service supports:

- **Chat Completions**: Conversational AI with context awareness
- **Text Generation**: Complete prompts and generate text content

#### Available Models
Common OCI Generative AI offerings include models from Meta, Cohere, OpenAI, Grok, and more.

### Safety: Llama Guard
For content safety and moderation, this distribution uses Meta's Llama Guard model through the OCI Generative AI service to provide:
- Content filtering and moderation
- Policy compliance checking
- Harmful content detection

### Vector Storage: Multiple Options
The distribution supports several vector storage providers:
- **FAISS**: Local in-memory vector search
- **ChromaDB**: Distributed vector database
- **PGVector**: PostgreSQL with vector extensions

### Additional Services
- **Dataset I/O**: Local filesystem and Hugging Face integration
- **Tool Runtime**: Web search (Brave, Tavily) and RAG capabilities
- **Evaluation**: Meta reference evaluation framework

## Running Llama Stack with OCI

You can run the OCI distribution via Docker or a local virtual environment.

### Via venv

If you've set up your local development environment, you can also build the image using your local virtual environment.

```bash
OCI_AUTH=$OCI_AUTH_TYPE OCI_REGION=$OCI_REGION OCI_COMPARTMENT_OCID=$OCI_COMPARTMENT_OCID llama stack run --port 8321 oci
```

### Configuration Examples

#### Using Instance Principal (Recommended for Production)
```bash
export OCI_AUTH_TYPE=instance_principal
export OCI_REGION=us-chicago-1
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..<your-compartment-id>
```

#### Using API Key Authentication (Development)
```bash
export OCI_AUTH_TYPE=config_file
export OCI_CONFIG_FILE_PATH=~/.oci/config
export OCI_CLI_PROFILE=DEFAULT
export OCI_REGION=us-chicago-1
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..your-compartment-id
```
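Before starting the server with API key authentication, it can help to sanity-check the signing key setup. A minimal sketch (assumes the `oci` Python SDK from the `oci` pip package; the path and profile values are illustrative):

```python
import oci

# Load the profile the stack will use and validate it (raises on a malformed config).
config = oci.config.from_file(file_location="~/.oci/config", profile_name="DEFAULT")
oci.config.validate_config(config)
print(f"Using tenancy {config['tenancy']} in region {config['region']}")
```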
## Regional Endpoints

OCI Generative AI is available in multiple regions. The service automatically routes to the appropriate regional endpoint based on your configuration. For a full list of regional model availability, visit:

https://docs.oracle.com/en-us/iaas/Content/generative-ai/overview.htm#regions

## Troubleshooting

### Common Issues

1. **Authentication Errors**: Verify your OCI credentials and IAM policies
2. **Model Not Found**: Ensure the model OCID is correct and the model is available in your region
3. **Permission Denied**: Check compartment permissions and Generative AI service access
4. **Region Unavailable**: Verify the specified region supports Generative AI services

### Getting Help

For additional support:
- [OCI Generative AI Documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)
- [Llama Stack Issues](https://github.com/meta-llama/llama-stack/issues)
108
src/llama_stack/distributions/oci/oci.py
Normal file
108
src/llama_stack/distributions/oci/oci.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
|
||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.remote.inference.oci.config import OCIConfig
|
||||
|
||||
|
||||
def get_distribution_template(name: str = "oci") -> DistributionTemplate:
|
||||
providers = {
|
||||
"inference": [BuildProvider(provider_type="remote::oci")],
|
||||
"vector_io": [
|
||||
BuildProvider(provider_type="inline::faiss"),
|
||||
BuildProvider(provider_type="remote::chromadb"),
|
||||
BuildProvider(provider_type="remote::pgvector"),
|
||||
],
|
||||
"safety": [BuildProvider(provider_type="inline::llama-guard")],
|
||||
"agents": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"eval": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"datasetio": [
|
||||
BuildProvider(provider_type="remote::huggingface"),
|
||||
BuildProvider(provider_type="inline::localfs"),
|
||||
],
|
||||
"scoring": [
|
||||
BuildProvider(provider_type="inline::basic"),
|
||||
BuildProvider(provider_type="inline::llm-as-judge"),
|
||||
BuildProvider(provider_type="inline::braintrust"),
|
||||
],
|
||||
"tool_runtime": [
|
||||
BuildProvider(provider_type="remote::brave-search"),
|
||||
BuildProvider(provider_type="remote::tavily-search"),
|
||||
BuildProvider(provider_type="inline::rag-runtime"),
|
||||
BuildProvider(provider_type="remote::model-context-protocol"),
|
||||
],
|
||||
"files": [BuildProvider(provider_type="inline::localfs")],
|
||||
}
|
||||
|
||||
inference_provider = Provider(
|
||||
provider_id="oci",
|
||||
provider_type="remote::oci",
|
||||
config=OCIConfig.sample_run_config(),
|
||||
)
|
||||
|
||||
vector_io_provider = Provider(
|
||||
provider_id="faiss",
|
||||
provider_type="inline::faiss",
|
||||
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
|
||||
files_provider = Provider(
|
||||
provider_id="meta-reference-files",
|
||||
provider_type="inline::localfs",
|
||||
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::websearch",
|
||||
provider_id="tavily-search",
|
||||
),
|
||||
]
|
||||
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
distro_type="remote_hosted",
|
||||
description="Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM inference with scalable cloud services",
|
||||
container_image=None,
|
||||
template_path=Path(__file__).parent / "doc_template.md",
|
||||
providers=providers,
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider],
|
||||
"vector_io": [vector_io_provider],
|
||||
"files": [files_provider],
|
||||
},
|
||||
default_tool_groups=default_tool_groups,
|
||||
),
|
||||
},
|
||||
run_config_env_vars={
|
||||
"OCI_AUTH_TYPE": (
|
||||
"instance_principal",
|
||||
"OCI authentication type (instance_principal or config_file)",
|
||||
),
|
||||
"OCI_REGION": (
|
||||
"",
|
||||
"OCI region (e.g., us-ashburn-1, us-chicago-1, us-phoenix-1, eu-frankfurt-1)",
|
||||
),
|
||||
"OCI_COMPARTMENT_OCID": (
|
||||
"",
|
||||
"OCI compartment ID for the Generative AI service",
|
||||
),
|
||||
"OCI_CONFIG_FILE_PATH": (
|
||||
"~/.oci/config",
|
||||
"OCI config file path (required if OCI_AUTH_TYPE is config_file)",
|
||||
),
|
||||
"OCI_CLI_PROFILE": (
|
||||
"DEFAULT",
|
||||
"OCI CLI profile name to use from config file",
|
||||
),
|
||||
},
|
||||
)
|
||||
136
src/llama_stack/distributions/oci/run.yaml
Normal file
136
src/llama_stack/distributions/oci/run.yaml
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
version: 2
|
||||
image_name: oci
|
||||
apis:
|
||||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
- inference
|
||||
- safety
|
||||
- scoring
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: oci
|
||||
provider_type: remote::oci
|
||||
config:
|
||||
oci_auth_type: ${env.OCI_AUTH_TYPE:=instance_principal}
|
||||
oci_config_file_path: ${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}
|
||||
oci_config_profile: ${env.OCI_CLI_PROFILE:=DEFAULT}
|
||||
oci_region: ${env.OCI_REGION:=us-ashburn-1}
|
||||
oci_compartment_id: ${env.OCI_COMPARTMENT_OCID:=}
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
persistence:
|
||||
namespace: vector_io::faiss
|
||||
backend: kv_default
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence:
|
||||
agent_state:
|
||||
namespace: agents
|
||||
backend: kv_default
|
||||
responses:
|
||||
table_name: responses
|
||||
backend: sql_default
|
||||
max_write_queue_size: 10000
|
||||
num_writers: 4
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
kvstore:
|
||||
namespace: eval
|
||||
backend: kv_default
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config:
|
||||
kvstore:
|
||||
namespace: datasetio::huggingface
|
||||
backend: kv_default
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
kvstore:
|
||||
namespace: datasetio::localfs
|
||||
backend: kv_default
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:=}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/oci/files}
|
||||
metadata_store:
|
||||
table_name: files_metadata
|
||||
backend: sql_default
|
||||
storage:
|
||||
backends:
|
||||
kv_default:
|
||||
type: kv_sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/kvstore.db
|
||||
sql_default:
|
||||
type: sql_sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/sql_store.db
|
||||
stores:
|
||||
metadata:
|
||||
namespace: registry
|
||||
backend: kv_default
|
||||
inference:
|
||||
table_name: inference_store
|
||||
backend: sql_default
|
||||
max_write_queue_size: 10000
|
||||
num_writers: 4
|
||||
conversations:
|
||||
table_name: openai_conversations
|
||||
backend: sql_default
|
||||
prompts:
|
||||
namespace: prompts
|
||||
backend: kv_default
|
||||
registered_resources:
|
||||
models: []
|
||||
shields: []
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
|
|
@@ -26,8 +26,10 @@ from fairscale.nn.model_parallel.initialize import (
)
from termcolor import cprint

+from llama_stack.models.llama.datatypes import ToolPromptFormat
+
from ..checkpoint import maybe_reshard_state_dict
-from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
+from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage
from .args import ModelArgs
from .chat_format import ChatFormat, LLMInput
from .model import Transformer
@@ -15,13 +15,10 @@ from pathlib import Path

from termcolor import colored

from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall, ToolDefinition, ToolPromptFormat

from ..datatypes import (
    BuiltinTool,
    RawMessage,
    StopReason,
    ToolCall,
    ToolDefinition,
    ToolPromptFormat,
)
from . import template_data
from .chat_format import ChatFormat
@@ -15,7 +15,7 @@ import textwrap
from datetime import datetime
from typing import Any

-from llama_stack.apis.inference import (
+from llama_stack.models.llama.datatypes import (
    BuiltinTool,
    ToolDefinition,
)
@@ -8,8 +8,9 @@ import json
import re

from llama_stack.log import get_logger
+from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolPromptFormat

-from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
+from ..datatypes import RecursiveType

logger = get_logger(name=__name__, category="models::llama")
@@ -13,7 +13,7 @@

import textwrap

-from llama_stack.apis.inference import ToolDefinition
+from llama_stack.models.llama.datatypes import ToolDefinition
from llama_stack.models.llama.llama3.prompt_templates.base import (
    PromptTemplate,
    PromptTemplateGeneratorBase,
@@ -102,6 +102,7 @@ class MetaReferenceAgentsImpl(Agents):
        include: list[str] | None = None,
        max_infer_iters: int | None = 10,
        guardrails: list[ResponseGuardrail] | None = None,
+        max_tool_calls: int | None = None,
    ) -> OpenAIResponseObject:
        assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
        result = await self.openai_responses_impl.create_openai_response(

@@ -119,6 +120,7 @@ class MetaReferenceAgentsImpl(Agents):
            include,
            max_infer_iters,
            guardrails,
+            max_tool_calls,
        )
        return result  # type: ignore[no-any-return]
@@ -255,6 +255,7 @@ class OpenAIResponsesImpl:
        include: list[str] | None = None,
        max_infer_iters: int | None = 10,
        guardrails: list[str | ResponseGuardrailSpec] | None = None,
+        max_tool_calls: int | None = None,
    ):
        stream = bool(stream)
        text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text

@@ -270,6 +271,9 @@ class OpenAIResponsesImpl:
            if not conversation.startswith("conv_"):
                raise InvalidConversationIdError(conversation)

+        if max_tool_calls is not None and max_tool_calls < 1:
+            raise ValueError(f"Invalid {max_tool_calls=}; should be >= 1")
+
        stream_gen = self._create_streaming_response(
            input=input,
            conversation=conversation,

@@ -282,6 +286,7 @@ class OpenAIResponsesImpl:
            tools=tools,
            max_infer_iters=max_infer_iters,
            guardrail_ids=guardrail_ids,
+            max_tool_calls=max_tool_calls,
        )

        if stream:

@@ -331,6 +336,7 @@ class OpenAIResponsesImpl:
        tools: list[OpenAIResponseInputTool] | None = None,
        max_infer_iters: int | None = 10,
        guardrail_ids: list[str] | None = None,
+        max_tool_calls: int | None = None,
    ) -> AsyncIterator[OpenAIResponseObjectStream]:
        # These should never be None when called from create_openai_response (which sets defaults)
        # but we assert here to help mypy understand the types

@@ -373,6 +379,7 @@ class OpenAIResponsesImpl:
            safety_api=self.safety_api,
            guardrail_ids=guardrail_ids,
            instructions=instructions,
+            max_tool_calls=max_tool_calls,
        )

        # Stream the response
@@ -115,6 +115,7 @@ class StreamingResponseOrchestrator:
        safety_api,
        guardrail_ids: list[str] | None = None,
        prompt: OpenAIResponsePrompt | None = None,
+        max_tool_calls: int | None = None,
    ):
        self.inference_api = inference_api
        self.ctx = ctx

@@ -126,6 +127,10 @@ class StreamingResponseOrchestrator:
        self.safety_api = safety_api
        self.guardrail_ids = guardrail_ids or []
        self.prompt = prompt
+        # System message that is inserted into the model's context
+        self.instructions = instructions
+        # Max number of total calls to built-in tools that can be processed in a response
+        self.max_tool_calls = max_tool_calls
        self.sequence_number = 0
        # Store MCP tool mapping that gets built during tool processing
        self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (

@@ -139,8 +144,8 @@ class StreamingResponseOrchestrator:
        self.accumulated_usage: OpenAIResponseUsage | None = None
        # Track if we've sent a refusal response
        self.violation_detected = False
-        # system message that is inserted into the model's context
-        self.instructions = instructions
+        # Track total calls made to built-in tools
+        self.accumulated_builtin_tool_calls = 0

    async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
        """Create a refusal response to replace streaming content."""

@@ -186,6 +191,7 @@ class StreamingResponseOrchestrator:
            usage=self.accumulated_usage,
            instructions=self.instructions,
            prompt=self.prompt,
+            max_tool_calls=self.max_tool_calls,
        )

    async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:

@@ -894,6 +900,11 @@ class StreamingResponseOrchestrator:
        """Coordinate execution of both function and non-function tool calls."""
        # Execute non-function tool calls
        for tool_call in non_function_tool_calls:
+            # Check if total calls made to built-in and mcp tools exceed max_tool_calls
+            if self.max_tool_calls is not None and self.accumulated_builtin_tool_calls >= self.max_tool_calls:
+                logger.info(f"Ignoring built-in and mcp tool call since reached the limit of {self.max_tool_calls=}.")
+                break
+
            # Find the item_id for this tool call
            matching_item_id = None
            for index, item_id in completion_result_data.tool_call_item_ids.items():

@@ -974,6 +985,9 @@ class StreamingResponseOrchestrator:
            if tool_response_message:
                next_turn_messages.append(tool_response_message)

+            # Track number of calls made to built-in and mcp tools
+            self.accumulated_builtin_tool_calls += 1

        # Execute function tool calls (client-side)
        for tool_call in function_tool_calls:
            # Find the item_id for this tool call from our tracking dictionary
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
|
@ -14,21 +13,19 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken
|
|||
from llama_stack.apis.inference import (
|
||||
GreedySamplingStrategy,
|
||||
JsonSchemaResponseFormat,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import QuantizationMode
|
||||
from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat
|
||||
from llama_stack.models.llama.llama3.generation import Llama3
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||
from llama_stack.models.llama.llama4.generation import Llama4
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
from llama_stack.models.llama.sku_types import Model, ModelFamily
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
get_default_tool_prompt_format,
|
||||
)
|
||||
|
||||
from .common import model_checkpoint_dir
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
|
|
@ -106,14 +103,6 @@ def _infer_sampling_params(sampling_params: SamplingParams):
|
|||
return temperature, top_p
|
||||
|
||||
|
||||
def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
|
||||
tool_config = request.tool_config
|
||||
if tool_config is not None and tool_config.tool_prompt_format is not None:
|
||||
return tool_config.tool_prompt_format
|
||||
else:
|
||||
return get_default_tool_prompt_format(request.model)
|
||||
|
||||
|
||||
class LlamaGenerator:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -157,55 +146,56 @@ class LlamaGenerator:
|
|||
self.args = self.inner_generator.args
|
||||
self.formatter = self.inner_generator.formatter
|
||||
|
||||
def completion(
|
||||
self,
|
||||
request_batch: list[CompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
first_request = request_batch[0]
|
||||
sampling_params = first_request.sampling_params or SamplingParams()
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(first_request.logprobs),
|
||||
echo=False,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
first_request.response_format,
|
||||
),
|
||||
)
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
request_batch: list[ChatCompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
first_request = request_batch[0]
|
||||
sampling_params = first_request.sampling_params or SamplingParams()
|
||||
request: OpenAIChatCompletionRequestWithExtraBody,
|
||||
raw_messages: list,
|
||||
):
|
||||
"""Generate chat completion using OpenAI request format.
|
||||
|
||||
Args:
|
||||
request: OpenAI chat completion request
|
||||
raw_messages: Pre-converted list of RawMessage objects
|
||||
"""
|
||||
|
||||
# Determine tool prompt format
|
||||
tool_prompt_format = ToolPromptFormat.json if request.tools else ToolPromptFormat.json
|
||||
|
||||
# Prepare sampling params
|
||||
sampling_params = SamplingParams()
|
||||
if request.temperature is not None or request.top_p is not None:
|
||||
sampling_params.strategy = TopPSamplingStrategy(
|
||||
temperature=request.temperature if request.temperature is not None else 1.0,
|
||||
top_p=request.top_p if request.top_p is not None else 1.0,
|
||||
)
|
||||
if request.max_tokens:
|
||||
sampling_params.max_tokens = request.max_tokens
|
||||
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
|
||||
# Get logits processor for response format
|
||||
logits_processor = None
|
||||
if request.response_format:
|
||||
if isinstance(request.response_format, OpenAIResponseFormatJSONSchema):
|
||||
# Extract the actual schema from OpenAIJSONSchema TypedDict
|
||||
schema_dict = request.response_format.json_schema.get("schema") or {}
|
||||
json_schema_format = JsonSchemaResponseFormat(
|
||||
type=ResponseFormatType.json_schema,
|
||||
json_schema=schema_dict,
|
||||
)
|
||||
logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format)
|
||||
|
||||
# Generate
|
||||
yield from self.inner_generator.generate(
|
||||
llm_inputs=[
|
||||
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
|
||||
for request in request_batch
|
||||
],
|
||||
llm_inputs=[self.formatter.encode_dialog_prompt(raw_messages, tool_prompt_format)],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(first_request.logprobs),
|
||||
logprobs=False,
|
||||
echo=False,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
first_request.response_format,
|
||||
),
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,12 +5,19 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
InferenceProvider,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIChatCompletionUsage,
|
||||
OpenAIChoice,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIUserMessageParam,
|
||||
ToolChoice,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
|
|
@ -19,12 +26,20 @@ from llama_stack.apis.inference.inference import (
|
|||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import RawMessage, RawTextItem, ToolDefinition
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
||||
from llama_stack.models.llama.llama3.prompt_templates import (
|
||||
JsonCustomToolGenerator,
|
||||
SystemDefaultGenerator,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
||||
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
|
||||
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
|
||||
)
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.models.llama.sku_types import ModelFamily
|
||||
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
|
|
@ -44,6 +59,170 @@ log = get_logger(__name__, category="inference")
|
|||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
||||
|
||||
def _convert_openai_tool_to_tool_definition(tool) -> ToolDefinition:
|
||||
"""Convert OpenAI tool format to ToolDefinition format."""
|
||||
# OpenAI tools have function.name and function.parameters
|
||||
return ToolDefinition(
|
||||
tool_name=tool.function.name,
|
||||
description=tool.function.description or "",
|
||||
parameters=tool.function.parameters or {},
|
||||
)
|
||||
|
||||
|
||||
def _get_tool_choice_prompt(tool_choice, tools) -> str:
|
||||
"""Generate prompt text for tool_choice behavior."""
|
||||
if not tool_choice or tool_choice == ToolChoice.auto or tool_choice == "auto":
|
||||
return ""
|
||||
elif tool_choice == ToolChoice.required or tool_choice == "required":
|
||||
return "You MUST use one of the provided functions/tools to answer the user query."
|
||||
elif tool_choice == ToolChoice.none or tool_choice == "none":
|
||||
return ""
|
||||
else:
|
||||
# Specific tool specified
|
||||
return f"You MUST use the tool `{tool_choice}` to answer the user query."
|
||||
|
||||
|
||||
def _raw_content_as_str(content) -> str:
|
||||
"""Convert RawContent to string for system messages."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, RawTextItem):
|
||||
return content.text
|
||||
elif isinstance(content, list):
|
||||
return "\n".join(_raw_content_as_str(c) for c in content)
|
||||
else:
|
||||
return "<media>"
|
||||
|
||||
|
||||
def _augment_raw_messages_for_tools_llama_3_1(
|
||||
raw_messages: list[RawMessage],
|
||||
tools: list,
|
||||
tool_choice,
|
||||
) -> list[RawMessage]:
|
||||
"""Augment raw messages with tool definitions for Llama 3.1 style models."""
|
||||
messages = raw_messages.copy()
|
||||
existing_system_message = None
|
||||
if messages and messages[0].role == "system":
|
||||
existing_system_message = messages.pop(0)
|
||||
|
||||
sys_content = ""
|
||||
|
||||
# Add tool definitions first (if present)
|
||||
if tools:
|
||||
# Convert OpenAI tools to ToolDefinitions
|
||||
tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
|
||||
|
||||
# For OpenAI format, all tools are custom (have string names)
|
||||
tool_gen = JsonCustomToolGenerator()
|
||||
tool_template = tool_gen.gen(tool_definitions)
|
||||
sys_content += tool_template.render()
|
||||
sys_content += "\n"
|
||||
|
||||
# Add default system prompt
|
||||
default_gen = SystemDefaultGenerator()
|
||||
default_template = default_gen.gen()
|
||||
sys_content += default_template.render()
|
||||
|
||||
# Add existing system message if present
|
||||
if existing_system_message:
|
||||
sys_content += "\n" + _raw_content_as_str(existing_system_message.content)
|
||||
|
||||
# Add tool choice prompt if needed
|
||||
if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
|
||||
sys_content += "\n" + tool_choice_prompt
|
||||
|
||||
# Create new system message
|
||||
new_system_message = RawMessage(
|
||||
role="system",
|
||||
content=[RawTextItem(text=sys_content.strip())],
|
||||
)
|
||||
|
||||
return [new_system_message] + messages
|
||||
|
||||
|
||||
def _augment_raw_messages_for_tools_llama_4(
|
||||
raw_messages: list[RawMessage],
|
||||
tools: list,
|
||||
tool_choice,
|
||||
) -> list[RawMessage]:
|
||||
"""Augment raw messages with tool definitions for Llama 4/3.2/3.3 style models."""
|
||||
messages = raw_messages.copy()
|
||||
existing_system_message = None
|
||||
if messages and messages[0].role == "system":
|
||||
existing_system_message = messages.pop(0)
|
||||
|
||||
sys_content = ""
|
||||
|
||||
# Add tool definitions if present
|
||||
if tools:
|
||||
# Convert OpenAI tools to ToolDefinitions
|
||||
tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
|
||||
|
||||
# Use python_list format for Llama 4
|
||||
tool_gen = PythonListCustomToolGeneratorLlama4()
|
||||
system_prompt = None
|
||||
if existing_system_message:
|
||||
system_prompt = _raw_content_as_str(existing_system_message.content)
|
||||
|
||||
tool_template = tool_gen.gen(tool_definitions, system_prompt)
|
||||
sys_content = tool_template.render()
|
||||
elif existing_system_message:
|
||||
# No tools, just use existing system message
|
||||
sys_content = _raw_content_as_str(existing_system_message.content)
|
||||
|
||||
# Add tool choice prompt if needed
|
||||
if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
|
||||
sys_content += "\n" + tool_choice_prompt
|
||||
|
||||
if sys_content:
|
||||
new_system_message = RawMessage(
|
||||
role="system",
|
||||
content=[RawTextItem(text=sys_content.strip())],
|
||||
)
|
||||
return [new_system_message] + messages
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def augment_raw_messages_for_tools(
|
||||
raw_messages: list[RawMessage],
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
llama_model,
|
||||
) -> list[RawMessage]:
|
||||
"""Augment raw messages with tool definitions based on model family."""
|
||||
if not params.tools:
|
||||
return raw_messages
|
||||
|
||||
# Determine augmentation strategy based on model family
|
||||
if llama_model.model_family == ModelFamily.llama3_1 or (
|
||||
llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
|
||||
):
|
||||
# Llama 3.1 and Llama 3.2 multimodal use JSON format
|
||||
return _augment_raw_messages_for_tools_llama_3_1(
|
||||
raw_messages,
|
||||
params.tools,
|
||||
params.tool_choice,
|
||||
)
|
||||
elif llama_model.model_family in (
|
||||
ModelFamily.llama3_2,
|
||||
ModelFamily.llama3_3,
|
||||
ModelFamily.llama4,
|
||||
):
|
||||
# Llama 3.2/3.3/4 use python_list format
|
||||
return _augment_raw_messages_for_tools_llama_4(
|
||||
raw_messages,
|
||||
params.tools,
|
||||
params.tool_choice,
|
||||
)
|
||||
else:
|
||||
# Default to Llama 3.1 style
|
||||
return _augment_raw_messages_for_tools_llama_3_1(
|
||||
raw_messages,
|
||||
params.tools,
|
||||
params.tool_choice,
|
||||
)
|
||||
|
||||
|
||||
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
|
||||
return LlamaGenerator(config, model_id, llama_model)
|
||||
|
||||
|
|
@ -136,10 +315,13 @@ class MetaReferenceInferenceImpl(
|
|||
self.llama_model = llama_model
|
||||
|
||||
log.info("Warming up...")
|
||||
|
||||
await self.openai_chat_completion(
|
||||
model=model_id,
|
||||
messages=[{"role": "user", "content": "Hi how are you?"}],
|
||||
max_tokens=20,
|
||||
params=OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=model_id,
|
||||
messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")],
|
||||
max_tokens=20,
|
||||
)
|
||||
)
|
||||
log.info("Warmed up!")
|
||||
|
||||
|
|
@ -155,4 +337,207 @@ class MetaReferenceInferenceImpl(
|
|||
self,
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")
|
||||
self.check_model(params)
|
||||
|
||||
# Convert OpenAI messages to RawMessages
|
||||
from llama_stack.models.llama.datatypes import StopReason
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_openai_message_to_raw_message,
|
||||
decode_assistant_message,
|
||||
)
|
||||
|
||||
raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages]
|
||||
|
||||
# Augment messages with tool definitions if tools are present
|
||||
raw_messages = augment_raw_messages_for_tools(raw_messages, params, self.llama_model)
|
||||
|
||||
# Call generator's chat_completion method (works for both single-GPU and model-parallel)
|
||||
if isinstance(self.generator, LlamaGenerator):
|
||||
generator = self.generator.chat_completion(params, raw_messages)
|
||||
else:
|
||||
# Model parallel: submit task to process group
|
||||
generator = self.generator.group.run_inference(("chat_completion", [params, raw_messages]))
|
||||
|
||||
# Check if streaming is requested
|
||||
if params.stream:
|
||||
return self._stream_chat_completion(generator, params)
|
||||
|
||||
# Non-streaming: collect all generated text
|
||||
generated_text = ""
|
||||
for result_batch in generator:
|
||||
for result in result_batch:
|
||||
if not result.ignore_token and result.source == "output":
|
||||
generated_text += result.text
|
||||
|
||||
# Decode assistant message to extract tool calls and determine stop_reason
|
||||
# Default to end_of_turn if generation completed normally
|
||||
decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
|
||||
|
||||
# Convert tool calls to OpenAI format
|
||||
openai_tool_calls = None
|
||||
if decoded_message.tool_calls:
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
)
|
||||
|
||||
openai_tool_calls = [
|
||||
OpenAIChatCompletionToolCall(
|
||||
# generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
|
||||
id=f"call_{uuid.uuid4().hex[:24]}",
|
||||
type="function",
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name=str(tc.tool_name),
|
||||
arguments=tc.arguments,
|
||||
),
|
||||
)
|
||||
for tc in decoded_message.tool_calls
|
||||
]
|
||||
|
||||
# Determine finish_reason based on whether tool calls are present
|
||||
finish_reason = "tool_calls" if openai_tool_calls else "stop"
|
||||
|
||||
# Extract content from decoded message
|
||||
content = ""
|
||||
if isinstance(decoded_message.content, str):
|
||||
content = decoded_message.content
|
||||
elif isinstance(decoded_message.content, list):
|
||||
for item in decoded_message.content:
|
||||
if isinstance(item, RawTextItem):
|
||||
content += item.text
|
||||
|
||||
# Create OpenAI response
|
||||
# generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
|
||||
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
|
||||
created = int(time.time())
|
||||
|
||||
return OpenAIChatCompletion(
|
||||
id=response_id,
|
||||
object="chat.completion",
|
||||
created=created,
|
||||
model=params.model,
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
index=0,
|
||||
message=OpenAIAssistantMessageParam(
|
||||
role="assistant",
|
||||
content=content,
|
||||
tool_calls=openai_tool_calls,
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=OpenAIChatCompletionUsage(
|
||||
prompt_tokens=0, # TODO: calculate properly
|
||||
completion_tokens=0, # TODO: calculate properly
|
||||
total_tokens=0, # TODO: calculate properly
|
||||
),
|
||||
)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self,
|
||||
generator,
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""Stream chat completion chunks as they're generated."""
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoiceDelta,
|
||||
OpenAIChunkChoice,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import StopReason
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
|
||||
|
||||
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
|
||||
created = int(time.time())
|
||||
generated_text = ""
|
||||
|
||||
# Yield chunks as tokens are generated
|
||||
for result_batch in generator:
|
||||
for result in result_batch:
|
||||
if result.ignore_token or result.source != "output":
|
||||
continue
|
||||
|
||||
generated_text += result.text
|
||||
|
||||
# Yield delta chunk with the new text
|
||||
chunk = OpenAIChatCompletionChunk(
|
||||
id=response_id,
|
||||
object="chat.completion.chunk",
|
||||
created=created,
|
||||
model=params.model,
|
||||
choices=[
|
||||
OpenAIChunkChoice(
|
||||
index=0,
|
||||
delta=OpenAIChoiceDelta(
|
||||
role="assistant",
|
||||
content=result.text,
|
||||
),
|
||||
finish_reason="",
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
yield chunk
|
||||
|
||||
# After generation completes, decode the full message to extract tool calls
|
||||
decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
|
||||
|
||||
# If tool calls are present, yield a final chunk with tool_calls
|
||||
if decoded_message.tool_calls:
|
||||
openai_tool_calls = [
|
||||
OpenAIChatCompletionToolCall(
|
||||
# generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
|
||||
id=f"call_{uuid.uuid4().hex[:24]}",
|
||||
type="function",
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name=str(tc.tool_name),
|
||||
arguments=tc.arguments,
|
||||
),
|
||||
)
|
||||
for tc in decoded_message.tool_calls
|
||||
]
|
||||
|
||||
# Yield chunk with tool_calls
|
||||
chunk = OpenAIChatCompletionChunk(
|
||||
id=response_id,
|
||||
object="chat.completion.chunk",
|
||||
created=created,
|
||||
model=params.model,
|
||||
choices=[
|
||||
OpenAIChunkChoice(
|
||||
index=0,
|
||||
delta=OpenAIChoiceDelta(
|
||||
role="assistant",
|
||||
tool_calls=openai_tool_calls,
|
||||
),
|
||||
finish_reason="",
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
yield chunk
|
||||
|
||||
finish_reason = "tool_calls"
|
||||
else:
|
||||
finish_reason = "stop"
|
||||
|
||||
# Yield final chunk with finish_reason
|
||||
final_chunk = OpenAIChatCompletionChunk(
|
||||
id=response_id,
|
||||
object="chat.completion.chunk",
|
||||
created=created,
|
||||
model=params.model,
|
||||
choices=[
|
||||
OpenAIChunkChoice(
|
||||
index=0,
|
||||
delta=OpenAIChoiceDelta(),
|
||||
finish_reason=finish_reason,
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
yield final_chunk
|
||||
|
|
|
@@ -4,17 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import Callable, Generator
from copy import deepcopy
from collections.abc import Callable
from functools import partial
from typing import Any

from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.providers.utils.inference.prompt_adapter import (
    ChatCompletionRequestWithRawContent,
    CompletionRequestWithRawContent,
)

from .parallel_utils import ModelParallelProcessGroup

@@ -23,12 +18,14 @@ class ModelRunner:
    def __init__(self, llama):
        self.llama = llama

    # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
    def __call__(self, task: Any):
        if task[0] == "chat_completion":
            return self.llama.chat_completion(task[1])
        task_type = task[0]
        if task_type == "chat_completion":
            # task[1] is [params, raw_messages]
            params, raw_messages = task[1]
            return self.llama.chat_completion(params, raw_messages)
        else:
            raise ValueError(f"Unexpected task type {task[0]}")
            raise ValueError(f"Unexpected task type {task_type}")


def init_model_cb(

@@ -78,19 +75,3 @@ class LlamaModelParallelGenerator:

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.group.stop()

    def completion(
        self,
        request_batch: list[CompletionRequestWithRawContent],
    ) -> Generator:
        req_obj = deepcopy(request_batch)
        gen = self.group.run_inference(("completion", req_obj))
        yield from gen

    def chat_completion(
        self,
        request_batch: list[ChatCompletionRequestWithRawContent],
    ) -> Generator:
        req_obj = deepcopy(request_batch)
        gen = self.group.run_inference(("chat_completion", req_obj))
        yield from gen
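To make the new task payload concrete, here is a rough sketch of how a caller might submit work through ModelParallelProcessGroup.run_inference after this change; group, the params, and the message contents are placeholders, and the surrounding generator class is elided:

# Sketch: the new chat_completion task carries [params, raw_messages] instead of
# a request-with-raw-content batch. `group` is a started ModelParallelProcessGroup.
params = {"model": "llama-3.2-3b-instruct", "stream": True}   # placeholder params
raw_messages = [{"role": "user", "content": "Hello"}]         # placeholder messages
task = ("chat_completion", [params, raw_messages])
for generation_result in group.run_inference(task):
    print(generation_result)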
@@ -33,10 +33,6 @@ from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import (
    ChatCompletionRequestWithRawContent,
    CompletionRequestWithRawContent,
)

log = get_logger(name=__name__, category="inference")

@@ -69,10 +65,7 @@ class CancelSentinel(BaseModel):
class TaskRequest(BaseModel):
    type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
    task: tuple[
        str,
        list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
    ]
    task: tuple[str, list]


class TaskResponse(BaseModel):

@@ -328,10 +321,7 @@ class ModelParallelProcessGroup:
    def run_inference(
        self,
        req: tuple[
            str,
            list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
        ],
        req: tuple[str, list],
    ) -> Generator:
        assert not self.running, "inference already running"
@@ -22,9 +22,6 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
    SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAIChatCompletionToLlamaStackMixin,
)

from .config import SentenceTransformersInferenceConfig

@@ -32,7 +29,6 @@ log = get_logger(name=__name__, category="inference")

class SentenceTransformersInferenceImpl(
    OpenAIChatCompletionToLlamaStackMixin,
    SentenceTransformerEmbeddingMixin,
    InferenceProvider,
    ModelsProtocolPrivate,
@@ -297,6 +297,20 @@ Available Models:
Azure OpenAI inference provider for accessing GPT models and other Azure services.
Provider documentation
https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
""",
    ),
    RemoteProviderSpec(
        api=Api.inference,
        provider_type="remote::oci",
        adapter_type="oci",
        pip_packages=["oci"],
        module="llama_stack.providers.remote.inference.oci",
        config_class="llama_stack.providers.remote.inference.oci.config.OCIConfig",
        provider_data_validator="llama_stack.providers.remote.inference.oci.config.OCIProviderDataValidator",
        description="""
Oracle Cloud Infrastructure (OCI) Generative AI inference provider for accessing OCI's Generative AI Platform-as-a-Service models.
Provider documentation
https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm
""",
    ),
]
src/llama_stack/providers/remote/inference/oci/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import InferenceProvider

from .config import OCIConfig


async def get_adapter_impl(config: OCIConfig, _deps) -> InferenceProvider:
    from .oci import OCIInferenceAdapter

    adapter = OCIInferenceAdapter(config=config)
    await adapter.initialize()
    return adapter
src/llama_stack/providers/remote/inference/oci/auth.py (new file, 79 lines)
@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import Generator, Mapping
from typing import Any, override

import httpx
import oci
import requests
from oci.config import DEFAULT_LOCATION, DEFAULT_PROFILE

OciAuthSigner = type[oci.signer.AbstractBaseSigner]


class HttpxOciAuth(httpx.Auth):
    """
    Custom HTTPX authentication class that implements OCI request signing.

    This class handles the authentication flow for HTTPX requests by signing them
    using the OCI Signer, which adds the necessary authentication headers for
    OCI API calls.

    Attributes:
        signer (oci.signer.Signer): The OCI signer instance used for request signing
    """

    def __init__(self, signer: OciAuthSigner):
        self.signer = signer

    @override
    def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]:
        # Read the request content to handle streaming requests properly
        try:
            content = request.content
        except httpx.RequestNotRead:
            # For streaming requests, we need to read the content first
            content = request.read()

        req = requests.Request(
            method=request.method,
            url=str(request.url),
            headers=dict(request.headers),
            data=content,
        )
        prepared_request = req.prepare()

        # Sign the request using the OCI Signer
        self.signer.do_request_sign(prepared_request)  # type: ignore

        # Update the original HTTPX request with the signed headers
        request.headers.update(prepared_request.headers)

        yield request


class OciInstancePrincipalAuth(HttpxOciAuth):
    def __init__(self, **kwargs: Mapping[str, Any]):
        self.signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner(**kwargs)


class OciUserPrincipalAuth(HttpxOciAuth):
    def __init__(self, config_file: str = DEFAULT_LOCATION, profile_name: str = DEFAULT_PROFILE):
        config = oci.config.from_file(config_file, profile_name)
        oci.config.validate_config(config)  # type: ignore
        key_content = ""
        with open(config["key_file"]) as f:
            key_content = f.read()

        self.signer = oci.signer.Signer(
            tenancy=config["tenancy"],
            user=config["user"],
            fingerprint=config["fingerprint"],
            private_key_file_location=config.get("key_file"),
            pass_phrase="none",  # type: ignore
            private_key_content=key_content,
        )
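As a usage sketch only (not part of the diff), one of these auth classes can be handed to a plain httpx client, which is essentially what the adapter below does via DefaultAsyncHttpxClient; the compartment OCID is a placeholder:

# Sketch: an httpx client whose requests are signed by the OCI signer. The
# CompartmentId header is what the Generative AI endpoint expects; values are placeholders.
import httpx

auth = OciUserPrincipalAuth(config_file="~/.oci/config", profile_name="DEFAULT")
http_client = httpx.AsyncClient(
    auth=auth,
    headers={"CompartmentId": "ocid1.compartment.oc1..example"},
)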
src/llama_stack/providers/remote/inference/oci/config.py (new file, 75 lines)
@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from typing import Any

from pydantic import BaseModel, Field

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type


class OCIProviderDataValidator(BaseModel):
    oci_auth_type: str = Field(
        description="OCI authentication type (must be one of: instance_principal, config_file)",
    )
    oci_region: str = Field(
        description="OCI region (e.g., us-ashburn-1)",
    )
    oci_compartment_id: str = Field(
        description="OCI compartment ID for the Generative AI service",
    )
    oci_config_file_path: str | None = Field(
        default="~/.oci/config",
        description="OCI config file path (required if oci_auth_type is config_file)",
    )
    oci_config_profile: str | None = Field(
        default="DEFAULT",
        description="OCI config profile (required if oci_auth_type is config_file)",
    )


@json_schema_type
class OCIConfig(RemoteInferenceProviderConfig):
    oci_auth_type: str = Field(
        description="OCI authentication type (must be one of: instance_principal, config_file)",
        default_factory=lambda: os.getenv("OCI_AUTH_TYPE", "instance_principal"),
    )
    oci_region: str = Field(
        default_factory=lambda: os.getenv("OCI_REGION", "us-ashburn-1"),
        description="OCI region (e.g., us-ashburn-1)",
    )
    oci_compartment_id: str = Field(
        default_factory=lambda: os.getenv("OCI_COMPARTMENT_OCID", ""),
        description="OCI compartment ID for the Generative AI service",
    )
    oci_config_file_path: str = Field(
        default_factory=lambda: os.getenv("OCI_CONFIG_FILE_PATH", "~/.oci/config"),
        description="OCI config file path (required if oci_auth_type is config_file)",
    )
    oci_config_profile: str = Field(
        default_factory=lambda: os.getenv("OCI_CLI_PROFILE", "DEFAULT"),
        description="OCI config profile (required if oci_auth_type is config_file)",
    )

    @classmethod
    def sample_run_config(
        cls,
        oci_auth_type: str = "${env.OCI_AUTH_TYPE:=instance_principal}",
        oci_config_file_path: str = "${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}",
        oci_config_profile: str = "${env.OCI_CLI_PROFILE:=DEFAULT}",
        oci_region: str = "${env.OCI_REGION:=us-ashburn-1}",
        oci_compartment_id: str = "${env.OCI_COMPARTMENT_OCID:=}",
        **kwargs,
    ) -> dict[str, Any]:
        return {
            "oci_auth_type": oci_auth_type,
            "oci_config_file_path": oci_config_file_path,
            "oci_config_profile": oci_config_profile,
            "oci_region": oci_region,
            "oci_compartment_id": oci_compartment_id,
        }
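For illustration, the default_factory fields above mean the config can be driven entirely by environment variables. A sketch, assuming RemoteInferenceProviderConfig adds no further required fields; the compartment OCID is a placeholder:

# Sketch: building the provider config from the environment.
import os

os.environ["OCI_AUTH_TYPE"] = "config_file"
os.environ["OCI_REGION"] = "us-ashburn-1"
os.environ["OCI_COMPARTMENT_OCID"] = "ocid1.compartment.oc1..example"

config = OCIConfig()  # picks up the env vars via the default_factory lambdas
print(config.oci_auth_type, config.oci_region)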
src/llama_stack/providers/remote/inference/oci/oci.py (new file, 140 lines)
@@ -0,0 +1,140 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from collections.abc import Iterable
from typing import Any

import httpx
import oci
from oci.generative_ai.generative_ai_client import GenerativeAiClient
from oci.generative_ai.models import ModelCollection
from openai._base_client import DefaultAsyncHttpxClient

from llama_stack.apis.inference.inference import (
    OpenAIEmbeddingsRequestWithExtraBody,
    OpenAIEmbeddingsResponse,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.oci.auth import OciInstancePrincipalAuth, OciUserPrincipalAuth
from llama_stack.providers.remote.inference.oci.config import OCIConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

logger = get_logger(name=__name__, category="inference::oci")

OCI_AUTH_TYPE_INSTANCE_PRINCIPAL = "instance_principal"
OCI_AUTH_TYPE_CONFIG_FILE = "config_file"
VALID_OCI_AUTH_TYPES = [OCI_AUTH_TYPE_INSTANCE_PRINCIPAL, OCI_AUTH_TYPE_CONFIG_FILE]
DEFAULT_OCI_REGION = "us-ashburn-1"

MODEL_CAPABILITIES = ["TEXT_GENERATION", "TEXT_SUMMARIZATION", "TEXT_EMBEDDINGS", "CHAT"]


class OCIInferenceAdapter(OpenAIMixin):
    config: OCIConfig

    async def initialize(self) -> None:
        """Initialize and validate OCI configuration."""
        if self.config.oci_auth_type not in VALID_OCI_AUTH_TYPES:
            raise ValueError(
                f"Invalid OCI authentication type: {self.config.oci_auth_type}. "
                f"Valid types are one of: {VALID_OCI_AUTH_TYPES}"
            )

        if not self.config.oci_compartment_id:
            raise ValueError("OCI_COMPARTMENT_OCID is a required parameter. Either set in env variable or config.")

    def get_base_url(self) -> str:
        region = self.config.oci_region or DEFAULT_OCI_REGION
        return f"https://inference.generativeai.{region}.oci.oraclecloud.com/20231130/actions/v1"

    def get_api_key(self) -> str | None:
        # OCI doesn't use API keys, it uses request signing
        return "<NOTUSED>"

    def get_extra_client_params(self) -> dict[str, Any]:
        """
        Get extra parameters for the AsyncOpenAI client, including OCI-specific auth and headers.
        """
        auth = self._get_auth()
        compartment_id = self.config.oci_compartment_id or ""

        return {
            "http_client": DefaultAsyncHttpxClient(
                auth=auth,
                headers={
                    "CompartmentId": compartment_id,
                },
            ),
        }

    def _get_oci_signer(self) -> oci.signer.AbstractBaseSigner | None:
        if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
            return oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
        return None

    def _get_oci_config(self) -> dict:
        if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
            config = {"region": self.config.oci_region}
        elif self.config.oci_auth_type == OCI_AUTH_TYPE_CONFIG_FILE:
            config = oci.config.from_file(self.config.oci_config_file_path, self.config.oci_config_profile)
            if not config.get("region"):
                raise ValueError(
                    "Region not specified in config. Please specify in config or with OCI_REGION env variable."
                )

        return config

    def _get_auth(self) -> httpx.Auth:
        if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
            return OciInstancePrincipalAuth()
        elif self.config.oci_auth_type == OCI_AUTH_TYPE_CONFIG_FILE:
            return OciUserPrincipalAuth(
                config_file=self.config.oci_config_file_path, profile_name=self.config.oci_config_profile
            )
        else:
            raise ValueError(f"Invalid OCI authentication type: {self.config.oci_auth_type}")

    async def list_provider_model_ids(self) -> Iterable[str]:
        """
        List available models from OCI Generative AI service.
        """
        oci_config = self._get_oci_config()
        oci_signer = self._get_oci_signer()
        compartment_id = self.config.oci_compartment_id or ""

        if oci_signer is None:
            client = GenerativeAiClient(config=oci_config)
        else:
            client = GenerativeAiClient(config=oci_config, signer=oci_signer)

        models: ModelCollection = client.list_models(
            compartment_id=compartment_id, capability=MODEL_CAPABILITIES, lifecycle_state="ACTIVE"
        ).data

        seen_models = set()
        model_ids = []
        for model in models.items:
            if model.time_deprecated or model.time_on_demand_retired:
                continue

            if "CHAT" not in model.capabilities or "FINE_TUNE" in model.capabilities:
                continue

            # Use display_name + model_type as the key to avoid conflicts
            model_key = (model.display_name, ModelType.llm)
            if model_key in seen_models:
                continue

            seen_models.add(model_key)
            model_ids.append(model.display_name)

        return model_ids

    async def openai_embeddings(self, params: OpenAIEmbeddingsRequestWithExtraBody) -> OpenAIEmbeddingsResponse:
        # The constructed url is a mask that hits OCI's "chat" action, which is not supported for embeddings.
        raise NotImplementedError("OCI Provider does not (currently) support embeddings")
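Roughly speaking, OpenAIMixin combines the pieces defined above (base URL, dummy API key, signed http client) into an AsyncOpenAI client. A hedged end-to-end sketch; the model display name is a placeholder and the real wiring lives inside the mixin rather than user code:

# Sketch: how the adapter's hooks fit together against the OCI OpenAI-compatible endpoint.
import asyncio

from openai import AsyncOpenAI


async def demo() -> None:
    adapter = OCIInferenceAdapter(config=OCIConfig())
    await adapter.initialize()
    client = AsyncOpenAI(
        base_url=adapter.get_base_url(),
        api_key=adapter.get_api_key(),
        **adapter.get_extra_client_params(),
    )
    resp = await client.chat.completions.create(
        model="meta.llama-3.3-70b-instruct",  # placeholder model display name
        messages=[{"role": "user", "content": "Hello from OCI"}],
    )
    print(resp.choices[0].message.content)


asyncio.run(demo())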
|
|
@ -11,9 +11,7 @@ from collections.abc import AsyncIterator
|
|||
import litellm
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
InferenceProvider,
|
||||
JsonSchemaResponseFormat,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
|
|
@ -23,15 +21,11 @@ from llama_stack.apis.inference import (
|
|||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
ToolChoice,
|
||||
)
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict_new,
|
||||
convert_tooldef_to_openai_tool,
|
||||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
)
|
||||
|
||||
|
|
@ -127,51 +121,6 @@ class LiteLLMOpenAIMixin(
|
|||
|
||||
return schema
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
from typing import Any
|
||||
|
||||
input_dict: dict[str, Any] = {}
|
||||
|
||||
input_dict["messages"] = [
|
||||
await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
|
||||
]
|
||||
if fmt := request.response_format:
|
||||
if not isinstance(fmt, JsonSchemaResponseFormat):
|
||||
raise ValueError(
|
||||
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
|
||||
)
|
||||
|
||||
# Convert to dict for manipulation
|
||||
fmt_dict = dict(fmt.json_schema)
|
||||
name = fmt_dict["title"]
|
||||
del fmt_dict["title"]
|
||||
fmt_dict["additionalProperties"] = False
|
||||
|
||||
# Apply additionalProperties: False recursively to all objects
|
||||
fmt_dict = self._add_additional_properties_recursive(fmt_dict)
|
||||
|
||||
input_dict["response_format"] = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": name,
|
||||
"schema": fmt_dict,
|
||||
"strict": self.json_schema_strict,
|
||||
},
|
||||
}
|
||||
if request.tools:
|
||||
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
|
||||
if request.tool_config and (tool_choice := request.tool_config.tool_choice):
|
||||
input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice
|
||||
|
||||
return {
|
||||
"model": request.model,
|
||||
"api_key": self.get_api_key(),
|
||||
"api_base": self.api_base,
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
provider_data = self.get_request_provider_data()
|
||||
key_field = self.provider_data_api_key_field
|
||||
|
|
|
|||
File diff suppressed because it is too large
|
|
@ -21,19 +21,18 @@ from llama_stack.apis.common.content_types import (
|
|||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
Message,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIFile,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SystemMessage,
|
||||
SystemMessageBehavior,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
|
|
@ -42,33 +41,19 @@ from llama_stack.models.llama.datatypes import (
|
|||
RawMediaItem,
|
||||
RawMessage,
|
||||
RawTextItem,
|
||||
Role,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.prompt_templates import (
|
||||
BuiltinToolGenerator,
|
||||
FunctionTagCustomToolGenerator,
|
||||
JsonCustomToolGenerator,
|
||||
PythonListCustomToolGenerator,
|
||||
SystemDefaultGenerator,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
|
||||
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
|
||||
)
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
||||
log = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
|
||||
messages: list[RawMessage]
|
||||
|
||||
|
||||
class CompletionRequestWithRawContent(CompletionRequest):
|
||||
content: RawContent
|
||||
|
||||
|
|
@ -103,28 +88,6 @@ def interleaved_content_as_str(
|
|||
return _process(content)
|
||||
|
||||
|
||||
async def convert_request_to_raw(
|
||||
request: ChatCompletionRequest | CompletionRequest,
|
||||
) -> ChatCompletionRequestWithRawContent | CompletionRequestWithRawContent:
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
messages = []
|
||||
for m in request.messages:
|
||||
content = await interleaved_content_convert_to_raw(m.content)
|
||||
d = m.model_dump()
|
||||
d["content"] = content
|
||||
messages.append(RawMessage(**d))
|
||||
|
||||
d = request.model_dump()
|
||||
d["messages"] = messages
|
||||
request = ChatCompletionRequestWithRawContent(**d)
|
||||
else:
|
||||
d = request.model_dump()
|
||||
d["content"] = await interleaved_content_convert_to_raw(request.content)
|
||||
request = CompletionRequestWithRawContent(**d)
|
||||
|
||||
return request
|
||||
|
||||
|
||||
async def interleaved_content_convert_to_raw(
|
||||
content: InterleavedContent,
|
||||
) -> RawContent:
|
||||
|
|
@ -171,6 +134,36 @@ async def interleaved_content_convert_to_raw(
|
|||
return await _localize_single(content)
|
||||
|
||||
|
||||
async def convert_openai_message_to_raw_message(message: OpenAIMessageParam) -> RawMessage:
|
||||
"""Convert OpenAI message format to RawMessage format used by Llama formatters."""
|
||||
if isinstance(message, OpenAIUserMessageParam):
|
||||
content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type]
|
||||
return RawMessage(role="user", content=content)
|
||||
elif isinstance(message, OpenAISystemMessageParam):
|
||||
content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type]
|
||||
return RawMessage(role="system", content=content)
|
||||
elif isinstance(message, OpenAIAssistantMessageParam):
|
||||
content = await interleaved_content_convert_to_raw(message.content or "") # type: ignore[arg-type]
|
||||
tool_calls = []
|
||||
if message.tool_calls:
|
||||
for tc in message.tool_calls:
|
||||
if tc.function:
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=tc.id or "",
|
||||
tool_name=tc.function.name or "",
|
||||
arguments=tc.function.arguments or "{}",
|
||||
)
|
||||
)
|
||||
return RawMessage(role="assistant", content=content, tool_calls=tool_calls)
|
||||
elif isinstance(message, OpenAIToolMessageParam):
|
||||
content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type]
|
||||
return RawMessage(role="tool", content=content)
|
||||
else:
|
||||
# Handle OpenAIDeveloperMessageParam if needed
|
||||
raise ValueError(f"Unsupported message type: {type(message)}")
|
||||
|
||||
|
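A short illustrative use of the new helper, assuming an async context; the message text is arbitrary:

# Sketch: converting an OpenAI-style user message into the RawMessage shape
# consumed by the Llama chat formatters (run inside an async function).
msg = OpenAIUserMessageParam(content="What is the weather in Paris?")
raw = await convert_openai_message_to_raw_message(msg)
assert raw.role == "user"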
||||
def content_has_media(content: InterleavedContent):
|
||||
def _has_media_content(c):
|
||||
return isinstance(c, ImageContentItem)
|
||||
|
|
@ -181,17 +174,6 @@ def content_has_media(content: InterleavedContent):
|
|||
return _has_media_content(content)
|
||||
|
||||
|
||||
def messages_have_media(messages: list[Message]):
|
||||
return any(content_has_media(m.content) for m in messages)
|
||||
|
||||
|
||||
def request_has_media(request: ChatCompletionRequest | CompletionRequest):
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
return messages_have_media(request.messages)
|
||||
else:
|
||||
return content_has_media(request.content)
|
||||
|
||||
|
||||
async def localize_image_content(uri: str) -> tuple[bytes, str] | None:
|
||||
if uri.startswith("http"):
|
||||
async with httpx.AsyncClient() as client:
|
||||
|
|
@ -253,79 +235,6 @@ def augment_content_with_response_format_prompt(response_format, content):
|
|||
return content
|
||||
|
||||
|
||||
async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
|
||||
messages = chat_completion_request_to_messages(request, llama_model)
|
||||
request.messages = messages
|
||||
request = await convert_request_to_raw(request)
|
||||
|
||||
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
|
||||
model_input = formatter.encode_dialog_prompt(
|
||||
request.messages,
|
||||
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
|
||||
)
|
||||
return formatter.tokenizer.decode(model_input.tokens)
|
||||
|
||||
|
||||
async def chat_completion_request_to_model_input_info(
|
||||
request: ChatCompletionRequest, llama_model: str
|
||||
) -> tuple[str, int]:
|
||||
messages = chat_completion_request_to_messages(request, llama_model)
|
||||
request.messages = messages
|
||||
request = await convert_request_to_raw(request)
|
||||
|
||||
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
|
||||
model_input = formatter.encode_dialog_prompt(
|
||||
request.messages,
|
||||
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
|
||||
)
|
||||
return (
|
||||
formatter.tokenizer.decode(model_input.tokens),
|
||||
len(model_input.tokens),
|
||||
)
|
||||
|
||||
|
||||
def chat_completion_request_to_messages(
|
||||
request: ChatCompletionRequest,
|
||||
llama_model: str,
|
||||
) -> list[Message]:
|
||||
"""Reads chat completion request and augments the messages to handle tools.
|
||||
For example, for llama_3_1, add a system message with the appropriate tools, or
add a user message for custom tools, etc.
|
||||
"""
|
||||
assert llama_model is not None, "llama_model is required"
|
||||
model = resolve_model(llama_model)
|
||||
if model is None:
|
||||
log.error(f"Could not resolve model {llama_model}")
|
||||
return request.messages
|
||||
|
||||
allowed_models = supported_inference_models()
|
||||
descriptors = [m.descriptor() for m in allowed_models]
|
||||
if model.descriptor() not in descriptors:
|
||||
log.error(f"Unsupported inference model? {model.descriptor()}")
|
||||
return request.messages
|
||||
|
||||
if model.model_family == ModelFamily.llama3_1 or (
|
||||
model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
|
||||
):
|
||||
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
|
||||
messages = augment_messages_for_tools_llama_3_1(request)
|
||||
elif model.model_family in (
|
||||
ModelFamily.llama3_2,
|
||||
ModelFamily.llama3_3,
|
||||
):
|
||||
# llama3.2, llama3.3 follow the same tool prompt format
|
||||
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
|
||||
elif model.model_family == ModelFamily.llama4:
|
||||
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
|
||||
else:
|
||||
messages = request.messages
|
||||
|
||||
if fmt_prompt := response_format_prompt(request.response_format):
|
||||
messages.append(UserMessage(content=fmt_prompt))
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def response_format_prompt(fmt: ResponseFormat | None):
|
||||
if not fmt:
|
||||
return None
|
||||
|
|
@ -338,128 +247,6 @@ def response_format_prompt(fmt: ResponseFormat | None):
|
|||
raise ValueError(f"Unknown response format {fmt.type}")
|
||||
|
||||
|
||||
def augment_messages_for_tools_llama_3_1(
|
||||
request: ChatCompletionRequest,
|
||||
) -> list[Message]:
|
||||
existing_messages = request.messages
|
||||
existing_system_message = None
|
||||
if existing_messages[0].role == Role.system.value:
|
||||
existing_system_message = existing_messages.pop(0)
|
||||
|
||||
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
|
||||
|
||||
messages = []
|
||||
|
||||
default_gen = SystemDefaultGenerator()
|
||||
default_template = default_gen.gen()
|
||||
|
||||
sys_content = ""
|
||||
|
||||
tool_template = None
|
||||
if request.tools:
|
||||
tool_gen = BuiltinToolGenerator()
|
||||
tool_template = tool_gen.gen(request.tools)
|
||||
|
||||
sys_content += tool_template.render()
|
||||
sys_content += "\n"
|
||||
|
||||
sys_content += default_template.render()
|
||||
|
||||
if existing_system_message:
|
||||
# TODO: this fn is needed in many places
|
||||
def _process(c):
|
||||
if isinstance(c, str):
|
||||
return c
|
||||
else:
|
||||
return "<media>"
|
||||
|
||||
sys_content += "\n"
|
||||
|
||||
if isinstance(existing_system_message.content, str):
|
||||
sys_content += _process(existing_system_message.content)
|
||||
elif isinstance(existing_system_message.content, list):
|
||||
sys_content += "\n".join([_process(c) for c in existing_system_message.content])
|
||||
|
||||
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
|
||||
if tool_choice_prompt:
|
||||
sys_content += "\n" + tool_choice_prompt
|
||||
|
||||
messages.append(SystemMessage(content=sys_content))
|
||||
|
||||
has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools)
|
||||
if has_custom_tools:
|
||||
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
|
||||
if fmt == ToolPromptFormat.json:
|
||||
tool_gen = JsonCustomToolGenerator()
|
||||
elif fmt == ToolPromptFormat.function_tag:
|
||||
tool_gen = FunctionTagCustomToolGenerator()
|
||||
else:
|
||||
raise ValueError(f"Non supported ToolPromptFormat {fmt}")
|
||||
|
||||
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
|
||||
custom_template = tool_gen.gen(custom_tools)
|
||||
messages.append(UserMessage(content=custom_template.render()))
|
||||
|
||||
# Add back existing messages from the request
|
||||
messages += existing_messages
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def augment_messages_for_tools_llama(
|
||||
request: ChatCompletionRequest,
|
||||
custom_tool_prompt_generator,
|
||||
) -> list[Message]:
|
||||
existing_messages = request.messages
|
||||
existing_system_message = None
|
||||
if existing_messages[0].role == Role.system.value:
|
||||
existing_system_message = existing_messages.pop(0)
|
||||
|
||||
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
|
||||
|
||||
sys_content = ""
|
||||
custom_tools, builtin_tools = [], []
|
||||
for t in request.tools:
|
||||
if isinstance(t.tool_name, str):
|
||||
custom_tools.append(t)
|
||||
else:
|
||||
builtin_tools.append(t)
|
||||
|
||||
if builtin_tools:
|
||||
tool_gen = BuiltinToolGenerator()
|
||||
tool_template = tool_gen.gen(builtin_tools)
|
||||
|
||||
sys_content += tool_template.render()
|
||||
sys_content += "\n"
|
||||
|
||||
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
|
||||
if custom_tools:
|
||||
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
|
||||
if fmt != ToolPromptFormat.python_list:
|
||||
raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
|
||||
|
||||
system_prompt = None
|
||||
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
|
||||
system_prompt = existing_system_message.content
|
||||
|
||||
tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
|
||||
|
||||
sys_content += tool_template.render()
|
||||
sys_content += "\n"
|
||||
|
||||
if existing_system_message and (
|
||||
request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
|
||||
):
|
||||
sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
|
||||
|
||||
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
|
||||
if tool_choice_prompt:
|
||||
sys_content += "\n" + tool_choice_prompt
|
||||
|
||||
messages = [SystemMessage(content=sys_content.strip("\n")), *existing_messages]
|
||||
return messages
|
||||
|
||||
|
||||
def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str:
|
||||
if tool_choice == ToolChoice.auto:
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreContent,
|
||||
VectorStoreDeleteResponse,
|
||||
VectorStoreFileBatchObject,
|
||||
VectorStoreFileContentsResponse,
|
||||
VectorStoreFileContentResponse,
|
||||
VectorStoreFileCounts,
|
||||
VectorStoreFileDeleteResponse,
|
||||
VectorStoreFileLastError,
|
||||
|
|
@ -921,22 +921,21 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
) -> VectorStoreFileContentResponse:
|
||||
"""Retrieves the contents of a vector store file."""
|
||||
if vector_store_id not in self.openai_vector_stores:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
|
||||
dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
|
||||
chunks = [Chunk.model_validate(c) for c in dict_chunks]
|
||||
content = []
|
||||
for chunk in chunks:
|
||||
content.extend(self._chunk_to_vector_store_content(chunk))
|
||||
return VectorStoreFileContentsResponse(
|
||||
file_id=file_id,
|
||||
filename=file_info.get("filename", ""),
|
||||
attributes=file_info.get("attributes", {}),
|
||||
content=content,
|
||||
return VectorStoreFileContentResponse(
|
||||
object="vector_store.file_content.page",
|
||||
data=content,
|
||||
has_more=False,
|
||||
next_page=None,
|
||||
)
|
||||
|
||||
async def openai_update_vector_store_file(
|
||||
|
|
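To make the new response shape concrete, a sketch of client-side consumption; `page` stands for the value returned by the retrieve-file-contents call, which is elided here:

# Sketch: consuming the new paginated file-content shape.
assert page.object == "vector_store.file_content.page"
for part in page.data:
    text = part.text if hasattr(part, "text") else part["text"]
    print(text)
assert page.has_more is False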
|
|||
|
|
@ -516,3 +516,169 @@ def test_response_with_instructions(openai_client, client_with_models, text_mode
|
|||
|
||||
# Verify instructions from previous response was not carried over to the next response
|
||||
assert response_with_instructions2.instructions == instructions2
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Tool calling is not reliable.")
|
||||
def test_max_tool_calls_with_function_tools(openai_client, client_with_models, text_model_id):
|
||||
"""Test handling of max_tool_calls with function tools in responses."""
|
||||
if isinstance(client_with_models, LlamaStackAsLibraryClient):
|
||||
pytest.skip("OpenAI responses are not supported when testing with library client yet.")
|
||||
|
||||
client = openai_client
|
||||
max_tool_calls = 1
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get weather information for a specified location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city name (e.g., 'New York', 'London')",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_time",
|
||||
"description": "Get current time for a specified location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city name (e.g., 'New York', 'London')",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# First create a response that triggers function tools
|
||||
response = client.responses.create(
|
||||
model=text_model_id,
|
||||
input="Can you tell me the weather in Paris and the current time?",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
max_tool_calls=max_tool_calls,
|
||||
)
|
||||
|
||||
# Verify we got two function calls and that the max_tool_calls do not affect function tools
|
||||
assert len(response.output) == 2
|
||||
assert response.output[0].type == "function_call"
|
||||
assert response.output[0].name == "get_weather"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[1].type == "function_call"
|
||||
assert response.output[1].name == "get_time"
|
||||
assert response.output[0].status == "completed"
|
||||
|
||||
# Verify we have a valid max_tool_calls field
|
||||
assert response.max_tool_calls == max_tool_calls
|
||||
|
||||
|
||||
def test_max_tool_calls_invalid(openai_client, client_with_models, text_model_id):
|
||||
"""Test handling of invalid max_tool_calls in responses."""
|
||||
if isinstance(client_with_models, LlamaStackAsLibraryClient):
|
||||
pytest.skip("OpenAI responses are not supported when testing with library client yet.")
|
||||
|
||||
client = openai_client
|
||||
|
||||
input = "Search for today's top technology news."
|
||||
invalid_max_tool_calls = 0
|
||||
tools = [
|
||||
{"type": "web_search"},
|
||||
]
|
||||
|
||||
# Create a response with an invalid max_tool_calls value i.e. 0
|
||||
# Handle ValueError from LLS and BadRequestError from OpenAI client
|
||||
with pytest.raises((ValueError, BadRequestError)) as excinfo:
|
||||
client.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
max_tool_calls=invalid_max_tool_calls,
|
||||
)
|
||||
|
||||
error_message = str(excinfo.value)
|
||||
assert f"Invalid max_tool_calls={invalid_max_tool_calls}; should be >= 1" in error_message, (
|
||||
f"Expected error message about invalid max_tool_calls, got: {error_message}"
|
||||
)
|
||||
|
||||
|
||||
def test_max_tool_calls_with_builtin_tools(openai_client, client_with_models, text_model_id):
|
||||
"""Test handling of max_tool_calls with built-in tools in responses."""
|
||||
if isinstance(client_with_models, LlamaStackAsLibraryClient):
|
||||
pytest.skip("OpenAI responses are not supported when testing with library client yet.")
|
||||
|
||||
client = openai_client
|
||||
|
||||
input = "Search for today's top technology and a positive news story. You MUST make exactly two separate web search calls."
|
||||
max_tool_calls = [1, 5]
|
||||
tools = [
|
||||
{"type": "web_search"},
|
||||
]
|
||||
|
||||
# First create a response that triggers web_search tools without max_tool_calls
|
||||
response = client.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Verify we got two web search calls followed by a message
|
||||
assert len(response.output) == 3
|
||||
assert response.output[0].type == "web_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[1].type == "web_search_call"
|
||||
assert response.output[1].status == "completed"
|
||||
assert response.output[2].type == "message"
|
||||
assert response.output[2].status == "completed"
|
||||
assert response.output[2].role == "assistant"
|
||||
|
||||
# Next create a response that triggers web_search tools with max_tool_calls set to 1
|
||||
response_2 = client.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
max_tool_calls=max_tool_calls[0],
|
||||
)
|
||||
|
||||
# Verify we got one web search tool call followed by a message
|
||||
assert len(response_2.output) == 2
|
||||
assert response_2.output[0].type == "web_search_call"
|
||||
assert response_2.output[0].status == "completed"
|
||||
assert response_2.output[1].type == "message"
|
||||
assert response_2.output[1].status == "completed"
|
||||
assert response_2.output[1].role == "assistant"
|
||||
|
||||
# Verify we have a valid max_tool_calls field
|
||||
assert response_2.max_tool_calls == max_tool_calls[0]
|
||||
|
||||
# Finally create a response that triggers web_search tools with max_tool_calls set to 5
|
||||
response_3 = client.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
max_tool_calls=max_tool_calls[1],
|
||||
)
|
||||
|
||||
# Verify we got two web search calls followed by a message
|
||||
assert len(response_3.output) == 3
|
||||
assert response_3.output[0].type == "web_search_call"
|
||||
assert response_3.output[0].status == "completed"
|
||||
assert response_3.output[1].type == "web_search_call"
|
||||
assert response_3.output[1].status == "completed"
|
||||
assert response_3.output[2].type == "message"
|
||||
assert response_3.output[2].status == "completed"
|
||||
assert response_3.output[2].role == "assistant"
|
||||
|
||||
# Verify we have a valid max_tool_calls field
|
||||
assert response_3.max_tool_calls == max_tool_calls[1]
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
|
|||
# {"error":{"message":"Unknown request URL: GET /openai/v1/completions. Please check the URL for typos,
|
||||
# or see the docs at https://console.groq.com/docs/","type":"invalid_request_error","code":"unknown_url"}}
|
||||
"remote::groq",
|
||||
"remote::oci",
|
||||
"remote::gemini", # https://generativelanguage.googleapis.com/v1beta/openai/completions -> 404
|
||||
"remote::anthropic", # at least claude-3-{5,7}-{haiku,sonnet}-* / claude-{sonnet,opus}-4-* are not supported
|
||||
"remote::azure", # {'error': {'code': 'OperationNotSupported', 'message': 'The completion operation
|
||||
|
|
|
|||
|
|
@ -138,6 +138,7 @@ def skip_if_model_doesnt_support_openai_embeddings(client, model_id):
|
|||
"remote::runpod",
|
||||
"remote::sambanova",
|
||||
"remote::tgi",
|
||||
"remote::oci",
|
||||
):
|
||||
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")
|
||||
|
||||
|
|
|
|||
|
|
@ -907,16 +907,16 @@ def test_openai_vector_store_retrieve_file_contents(
|
|||
)
|
||||
|
||||
assert file_contents is not None
|
||||
assert len(file_contents.content) == 1
|
||||
content = file_contents.content[0]
|
||||
assert file_contents.object == "vector_store.file_content.page"
|
||||
assert len(file_contents.data) == 1
|
||||
content = file_contents.data[0]
|
||||
|
||||
# llama-stack-client returns a model, openai-python is a badboy and returns a dict
|
||||
if not isinstance(content, dict):
|
||||
content = content.model_dump()
|
||||
assert content["type"] == "text"
|
||||
assert content["text"] == test_content.decode("utf-8")
|
||||
assert file_contents.filename == file_name
|
||||
assert file_contents.attributes == attributes
|
||||
assert file_contents.has_more is False
|
||||
|
||||
|
||||
@vector_provider_wrapper
|
||||
|
|
@ -1483,14 +1483,12 @@ def test_openai_vector_store_file_batch_retrieve_contents(
|
|||
)
|
||||
|
||||
assert file_contents is not None
|
||||
assert file_contents.filename == file_data[i][0]
|
||||
assert len(file_contents.content) > 0
|
||||
assert file_contents.object == "vector_store.file_content.page"
|
||||
assert len(file_contents.data) > 0
|
||||
|
||||
# Verify the content matches what we uploaded
|
||||
content_text = (
|
||||
file_contents.content[0].text
|
||||
if hasattr(file_contents.content[0], "text")
|
||||
else file_contents.content[0]["text"]
|
||||
file_contents.data[0].text if hasattr(file_contents.data[0], "text") else file_contents.data[0]["text"]
|
||||
)
|
||||
assert file_data[i][1].decode("utf-8") in content_text
|
||||
|
||||
|
|
|
|||
|
|
@ -1,303 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionMessage,
|
||||
StopReason,
|
||||
SystemMessage,
|
||||
SystemMessageBehavior,
|
||||
ToolCall,
|
||||
ToolConfig,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_messages,
|
||||
chat_completion_request_to_prompt,
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
MODEL = "Llama3.1-8B-Instruct"
|
||||
MODEL3_2 = "Llama3.2-3B-Instruct"
|
||||
|
||||
|
||||
async def test_system_default():
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 2
|
||||
assert messages[-1].content == content
|
||||
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
|
||||
|
||||
|
||||
async def test_system_builtin_only():
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 2
|
||||
assert messages[-1].content == content
|
||||
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
|
||||
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
|
||||
|
||||
|
||||
async def test_system_custom_only():
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param1": {
|
||||
"type": "str",
|
||||
"description": "param1 description",
|
||||
},
|
||||
},
|
||||
"required": ["param1"],
|
||||
},
|
||||
)
|
||||
],
|
||||
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 3
|
||||
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
|
||||
|
||||
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
|
||||
assert messages[-1].content == content
|
||||
|
||||
|
||||
async def test_system_custom_and_builtin():
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param1": {
|
||||
"type": "str",
|
||||
"description": "param1 description",
|
||||
},
|
||||
},
|
||||
"required": ["param1"],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 3
|
||||
|
||||
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
|
||||
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
|
||||
|
||||
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
|
||||
assert messages[-1].content == content
|
||||
|
||||
|
||||
async def test_completion_message_encoding():
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL3_2,
|
||||
messages=[
|
||||
UserMessage(content="hello"),
|
||||
CompletionMessage(
|
||||
content="",
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
tool_name="custom1",
|
||||
arguments='{"param1": "value1"}', # arguments must be a JSON string
|
||||
call_id="123",
|
||||
)
|
||||
],
|
||||
),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param1": {
|
||||
"type": "str",
|
||||
"description": "param1 description",
|
||||
},
|
||||
},
|
||||
"required": ["param1"],
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
|
||||
)
|
||||
prompt = await chat_completion_request_to_prompt(request, request.model)
|
||||
assert '[custom1(param1="value1")]' in prompt
|
||||
|
||||
request.model = MODEL
|
||||
request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
|
||||
prompt = await chat_completion_request_to_prompt(request, request.model)
|
||||
assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
|
||||
|
||||
|
||||
async def test_user_provided_system_message():
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 2
|
||||
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
|
||||
|
||||
assert messages[-1].content == content
|
||||
|
||||
|
||||
async def test_replace_system_message_behavior_builtin_tools():
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format=ToolPromptFormat.python_list,
|
||||
system_message_behavior=SystemMessageBehavior.replace,
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
assert len(messages) == 2
|
||||
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
|
||||
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
|
||||
assert messages[-1].content == content
|
||||
|
||||
|
||||
async def test_replace_system_message_behavior_custom_tools():
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param1": {
|
||||
"type": "str",
|
||||
"description": "param1 description",
|
||||
},
|
||||
},
|
||||
"required": ["param1"],
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format=ToolPromptFormat.python_list,
|
||||
system_message_behavior=SystemMessageBehavior.replace,
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
|
||||
assert len(messages) == 2
|
||||
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
|
||||
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
|
||||
assert messages[-1].content == content
|
||||
|
||||
|
||||
async def test_replace_system_message_behavior_custom_tools_with_template():
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate {{ function_description }}"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param1": {
|
||||
"type": "str",
|
||||
"description": "param1 description",
|
||||
},
|
||||
},
|
||||
"required": ["param1"],
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format=ToolPromptFormat.python_list,
|
||||
system_message_behavior=SystemMessageBehavior.replace,
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
|
||||
assert len(messages) == 2
|
||||
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
|
||||
assert "You are a pirate" in interleaved_content_as_str(messages[0].content)
|
||||
# function description is present in the system prompt
|
||||
assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content)
|
||||
assert messages[-1].content == content
|
||||
tests/unit/providers/inline/inference/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

tests/unit/providers/inline/inference/test_meta_reference.py (new file, 44 lines)
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import Mock

import pytest

from llama_stack.providers.inline.inference.meta_reference.model_parallel import (
    ModelRunner,
)


class TestModelRunner:
    """Test ModelRunner task dispatching for model-parallel inference."""

    def test_chat_completion_task_dispatch(self):
        """Verify ModelRunner correctly dispatches chat_completion tasks."""
        # Create a mock generator
        mock_generator = Mock()
        mock_generator.chat_completion = Mock(return_value=iter([]))

        runner = ModelRunner(mock_generator)

        # Create a chat_completion task
        fake_params = {"model": "test"}
        fake_messages = [{"role": "user", "content": "test"}]
        task = ("chat_completion", [fake_params, fake_messages])

        # Execute task
        runner(task)

        # Verify chat_completion was called with correct arguments
        mock_generator.chat_completion.assert_called_once_with(fake_params, fake_messages)

    def test_invalid_task_type_raises_error(self):
        """Verify ModelRunner rejects invalid task types."""
        mock_generator = Mock()
        runner = ModelRunner(mock_generator)

        with pytest.raises(ValueError, match="Unexpected task type"):
            runner(("invalid_task", []))
@ -10,11 +10,13 @@ from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from llama_stack.apis.inference import CompletionMessage, UserMessage
from llama_stack.apis.inference import (
    OpenAIAssistantMessageParam,
    OpenAIUserMessageParam,
)
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.safety import RunShieldResponse, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter


@ -136,11 +138,9 @@ async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post):

    # Run the shield
    messages = [
        UserMessage(role="user", content="Hello, how are you?"),
        CompletionMessage(
            role="assistant",
        OpenAIUserMessageParam(content="Hello, how are you?"),
        OpenAIAssistantMessageParam(
            content="I'm doing well, thank you for asking!",
            stop_reason=StopReason.end_of_message,
            tool_calls=[],
        ),
    ]

@ -191,13 +191,10 @@ async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post):
    # Mock Guardrails API response
    mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}

    # Run the shield
    messages = [
        UserMessage(role="user", content="Hello, how are you?"),
        CompletionMessage(
            role="assistant",
        OpenAIUserMessageParam(content="Hello, how are you?"),
        OpenAIAssistantMessageParam(
            content="I'm doing well, thank you for asking!",
            stop_reason=StopReason.end_of_message,
            tool_calls=[],
        ),
    ]

@ -243,7 +240,7 @@ async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post):
    adapter.shield_store.get_shield.return_value = None

    messages = [
        UserMessage(role="user", content="Hello, how are you?"),
        OpenAIUserMessageParam(content="Hello, how are you?"),
    ]

    with pytest.raises(ValueError):

@ -274,11 +271,9 @@ async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post):

    # Running the shield should raise an exception
    messages = [
        UserMessage(role="user", content="Hello, how are you?"),
        CompletionMessage(
            role="assistant",
        OpenAIUserMessageParam(content="Hello, how are you?"),
        OpenAIAssistantMessageParam(
            content="I'm doing well, thank you for asking!",
            stop_reason=StopReason.end_of_message,
            tool_calls=[],
        ),
    ]
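The hunks above migrate the NVIDIA safety tests from llama-style UserMessage/CompletionMessage to OpenAI-compatible message params. A minimal hedged sketch of the new construction, using only the types imported in this diff (role, stop_reason, and tool_calls arguments are no longer passed):

from llama_stack.apis.inference import (
    OpenAIAssistantMessageParam,
    OpenAIUserMessageParam,
)

# Illustrative only: role defaults are implied by the param classes themselves.
messages = [
    OpenAIUserMessageParam(content="Hello, how are you?"),
    OpenAIAssistantMessageParam(content="I'm doing well, thank you for asking!"),
]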
@ -1,220 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from pydantic import ValidationError

from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import (
    CompletionMessage,
    OpenAIAssistantMessageParam,
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
    OpenAIDeveloperMessageParam,
    OpenAIImageURL,
    OpenAISystemMessageParam,
    OpenAIToolMessageParam,
    OpenAIUserMessageParam,
    SystemMessage,
    UserMessage,
)
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.providers.utils.inference.openai_compat import (
    convert_message_to_openai_dict,
    convert_message_to_openai_dict_new,
    openai_messages_to_messages,
)


async def test_convert_message_to_openai_dict():
    message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
    assert await convert_message_to_openai_dict(message) == {
        "role": "user",
        "content": [{"type": "text", "text": "Hello, world!"}],
    }


# Test convert_message_to_openai_dict with a tool call
async def test_convert_message_to_openai_dict_with_tool_call():
    message = CompletionMessage(
        content="",
        tool_calls=[ToolCall(call_id="123", tool_name="test_tool", arguments='{"foo": "bar"}')],
        stop_reason=StopReason.end_of_turn,
    )

    openai_dict = await convert_message_to_openai_dict(message)

    assert openai_dict == {
        "role": "assistant",
        "content": [{"type": "text", "text": ""}],
        "tool_calls": [
            {"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}}
        ],
    }


async def test_convert_message_to_openai_dict_with_builtin_tool_call():
    message = CompletionMessage(
        content="",
        tool_calls=[
            ToolCall(
                call_id="123",
                tool_name=BuiltinTool.brave_search,
                arguments='{"foo": "bar"}',
            )
        ],
        stop_reason=StopReason.end_of_turn,
    )

    openai_dict = await convert_message_to_openai_dict(message)

    assert openai_dict == {
        "role": "assistant",
        "content": [{"type": "text", "text": ""}],
        "tool_calls": [
            {"id": "123", "type": "function", "function": {"name": "brave_search", "arguments": '{"foo": "bar"}'}}
        ],
    }


async def test_openai_messages_to_messages_with_content_str():
    openai_messages = [
        OpenAISystemMessageParam(content="system message"),
        OpenAIUserMessageParam(content="user message"),
        OpenAIAssistantMessageParam(content="assistant message"),
    ]

    llama_messages = openai_messages_to_messages(openai_messages)
    assert len(llama_messages) == 3
    assert isinstance(llama_messages[0], SystemMessage)
    assert isinstance(llama_messages[1], UserMessage)
    assert isinstance(llama_messages[2], CompletionMessage)
    assert llama_messages[0].content == "system message"
    assert llama_messages[1].content == "user message"
    assert llama_messages[2].content == "assistant message"


async def test_openai_messages_to_messages_with_content_list():
    openai_messages = [
        OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),
        OpenAIUserMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="user message")]),
        OpenAIAssistantMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="assistant message")]),
    ]

    llama_messages = openai_messages_to_messages(openai_messages)
    assert len(llama_messages) == 3
    assert isinstance(llama_messages[0], SystemMessage)
    assert isinstance(llama_messages[1], UserMessage)
    assert isinstance(llama_messages[2], CompletionMessage)
    assert llama_messages[0].content[0].text == "system message"
    assert llama_messages[1].content[0].text == "user message"
    assert llama_messages[2].content[0].text == "assistant message"


@pytest.mark.parametrize(
    "message_class,kwargs",
    [
        (OpenAISystemMessageParam, {}),
        (OpenAIAssistantMessageParam, {}),
        (OpenAIDeveloperMessageParam, {}),
        (OpenAIUserMessageParam, {}),
        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
    ],
)
def test_message_accepts_text_string(message_class, kwargs):
    """Test that messages accept string text content."""
    msg = message_class(content="Test message", **kwargs)
    assert msg.content == "Test message"


@pytest.mark.parametrize(
    "message_class,kwargs",
    [
        (OpenAISystemMessageParam, {}),
        (OpenAIAssistantMessageParam, {}),
        (OpenAIDeveloperMessageParam, {}),
        (OpenAIUserMessageParam, {}),
        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
    ],
)
def test_message_accepts_text_list(message_class, kwargs):
    """Test that messages accept list of text content parts."""
    content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
    msg = message_class(content=content_list, **kwargs)
    assert len(msg.content) == 1
    assert msg.content[0].text == "Test message"


@pytest.mark.parametrize(
    "message_class,kwargs",
    [
        (OpenAISystemMessageParam, {}),
        (OpenAIAssistantMessageParam, {}),
        (OpenAIDeveloperMessageParam, {}),
        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
    ],
)
def test_message_rejects_images(message_class, kwargs):
    """Test that system, assistant, developer, and tool messages reject image content."""
    with pytest.raises(ValidationError):
        message_class(
            content=[
                OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
            ],
            **kwargs,
        )


def test_user_message_accepts_images():
    """Test that user messages accept image content (unlike other message types)."""
    # List with images should work
    msg = OpenAIUserMessageParam(
        content=[
            OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
            OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
        ]
    )
    assert len(msg.content) == 2
    assert msg.content[0].text == "Describe this image:"
    assert msg.content[1].image_url.url == "http://example.com/image.jpg"


async def test_convert_message_to_openai_dict_new_user_message():
    """Test convert_message_to_openai_dict_new with UserMessage."""
    message = UserMessage(content="Hello, world!", role="user")
    result = await convert_message_to_openai_dict_new(message)

    assert result["role"] == "user"
    assert result["content"] == "Hello, world!"


async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls():
    """Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls."""
    message = CompletionMessage(
        content="I'll help you find the weather.",
        tool_calls=[
            ToolCall(
                call_id="call_123",
                tool_name="get_weather",
                arguments='{"city": "Sligo"}',
            )
        ],
        stop_reason=StopReason.end_of_turn,
    )
    result = await convert_message_to_openai_dict_new(message)

    # This would have failed with "Cannot instantiate typing.Union" before the fix
    assert result["role"] == "assistant"
    assert result["content"] == "I'll help you find the weather."
    assert "tool_calls" in result
    assert result["tool_calls"] is not None
    assert len(result["tool_calls"]) == 1

    tool_call = result["tool_calls"][0]
    assert tool_call.id == "call_123"
    assert tool_call.type == "function"
    assert tool_call.function.name == "get_weather"
    assert tool_call.function.arguments == '{"city": "Sligo"}'
35 tests/unit/providers/utils/inference/test_prompt_adapter.py Normal file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import (
    OpenAIAssistantMessageParam,
    OpenAIUserMessageParam,
)
from llama_stack.models.llama.datatypes import RawTextItem
from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_openai_message_to_raw_message,
)


class TestConvertOpenAIMessageToRawMessage:
    """Test conversion of OpenAI message types to RawMessage format."""

    async def test_user_message_conversion(self):
        msg = OpenAIUserMessageParam(role="user", content="Hello world")
        raw_msg = await convert_openai_message_to_raw_message(msg)

        assert raw_msg.role == "user"
        assert isinstance(raw_msg.content, RawTextItem)
        assert raw_msg.content.text == "Hello world"

    async def test_assistant_message_conversion(self):
        msg = OpenAIAssistantMessageParam(role="assistant", content="Hi there!")
        raw_msg = await convert_openai_message_to_raw_message(msg)

        assert raw_msg.role == "assistant"
        assert isinstance(raw_msg.content, RawTextItem)
        assert raw_msg.content.text == "Hi there!"
        assert raw_msg.tool_calls == []
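For context, a hedged usage sketch of the helper covered by the new test file above. It assumes only what the assertions state: the role is preserved and string content becomes a RawTextItem.

import asyncio

from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_openai_message_to_raw_message,
)


async def demo():
    # Illustrative sketch: convert an OpenAI-style user message into the raw
    # message format consumed by the meta-reference prompt adapter.
    msg = OpenAIUserMessageParam(role="user", content="Hello world")
    raw_msg = await convert_openai_message_to_raw_message(msg)
    print(raw_msg.role, raw_msg.content.text)


asyncio.run(demo())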