From d4ecbfd092a7502b4b3ffffbbc3df75c8c38862d Mon Sep 17 00:00:00 2001 From: ehhuang Date: Mon, 10 Nov 2025 10:16:35 -0800 Subject: [PATCH 01/15] fix(vector store)!: fix file content API (#4105) # What does this PR do? - changed to match https://app.stainless.com/api/spec/documented/openai/openapi.documented.yml ## Test Plan updated test CI --- client-sdks/stainless/openapi.yml | 48 ++++++++----------- docs/static/llama-stack-spec.yaml | 48 ++++++++----------- docs/static/stainless-llama-stack-spec.yaml | 48 ++++++++----------- src/llama_stack/apis/vector_io/vector_io.py | 24 +++++----- src/llama_stack/core/routers/vector_io.py | 4 +- .../core/routing_tables/vector_stores.py | 4 +- .../utils/memory/openai_vector_store_mixin.py | 15 +++--- .../vector_io/test_openai_vector_stores.py | 16 +++---- 8 files changed, 93 insertions(+), 114 deletions(-) diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml index d8159be62..adee2f086 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -2916,11 +2916,11 @@ paths: responses: '200': description: >- - A list of InterleavedContent representing the file contents. + A VectorStoreFileContentResponse representing the file contents. content: application/json: schema: - $ref: '#/components/schemas/VectorStoreFileContentsResponse' + $ref: '#/components/schemas/VectorStoreFileContentResponse' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -10465,41 +10465,35 @@ components: title: VectorStoreContent description: >- Content item from a vector store file or search result. 
- VectorStoreFileContentsResponse: + VectorStoreFileContentResponse: type: object properties: - file_id: + object: type: string - description: Unique identifier for the file - filename: - type: string - description: Name of the file - attributes: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object + const: vector_store.file_content.page + default: vector_store.file_content.page description: >- - Key-value attributes associated with the file - content: + The object type, which is always `vector_store.file_content.page` + data: type: array items: $ref: '#/components/schemas/VectorStoreContent' - description: List of content items from the file + description: Parsed content of the file + has_more: + type: boolean + description: >- + Indicates if there are more content pages to fetch + next_page: + type: string + description: The token for the next page, if any additionalProperties: false required: - - file_id - - filename - - attributes - - content - title: VectorStoreFileContentsResponse + - object + - data + - has_more + title: VectorStoreFileContentResponse description: >- - Response from retrieving the contents of a vector store file. + Represents the parsed content of a vector store file. OpenaiSearchVectorStoreRequest: type: object properties: diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index ea7fd6eec..72600bf13 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -2913,11 +2913,11 @@ paths: responses: '200': description: >- - A list of InterleavedContent representing the file contents. + A VectorStoreFileContentResponse representing the file contents. 
content: application/json: schema: - $ref: '#/components/schemas/VectorStoreFileContentsResponse' + $ref: '#/components/schemas/VectorStoreFileContentResponse' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -9749,41 +9749,35 @@ components: title: VectorStoreContent description: >- Content item from a vector store file or search result. - VectorStoreFileContentsResponse: + VectorStoreFileContentResponse: type: object properties: - file_id: + object: type: string - description: Unique identifier for the file - filename: - type: string - description: Name of the file - attributes: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object + const: vector_store.file_content.page + default: vector_store.file_content.page description: >- - Key-value attributes associated with the file - content: + The object type, which is always `vector_store.file_content.page` + data: type: array items: $ref: '#/components/schemas/VectorStoreContent' - description: List of content items from the file + description: Parsed content of the file + has_more: + type: boolean + description: >- + Indicates if there are more content pages to fetch + next_page: + type: string + description: The token for the next page, if any additionalProperties: false required: - - file_id - - filename - - attributes - - content - title: VectorStoreFileContentsResponse + - object + - data + - has_more + title: VectorStoreFileContentResponse description: >- - Response from retrieving the contents of a vector store file. + Represents the parsed content of a vector store file. 
OpenaiSearchVectorStoreRequest: type: object properties: diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index d8159be62..adee2f086 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -2916,11 +2916,11 @@ paths: responses: '200': description: >- - A list of InterleavedContent representing the file contents. + A VectorStoreFileContentResponse representing the file contents. content: application/json: schema: - $ref: '#/components/schemas/VectorStoreFileContentsResponse' + $ref: '#/components/schemas/VectorStoreFileContentResponse' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -10465,41 +10465,35 @@ components: title: VectorStoreContent description: >- Content item from a vector store file or search result. - VectorStoreFileContentsResponse: + VectorStoreFileContentResponse: type: object properties: - file_id: + object: type: string - description: Unique identifier for the file - filename: - type: string - description: Name of the file - attributes: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object + const: vector_store.file_content.page + default: vector_store.file_content.page description: >- - Key-value attributes associated with the file - content: + The object type, which is always `vector_store.file_content.page` + data: type: array items: $ref: '#/components/schemas/VectorStoreContent' - description: List of content items from the file + description: Parsed content of the file + has_more: + type: boolean + description: >- + Indicates if there are more content pages to fetch + next_page: + type: string + description: The token for the next page, if any additionalProperties: false required: - - file_id - - filename - - attributes - - content - title: VectorStoreFileContentsResponse + - object + - data + - has_more + title: 
VectorStoreFileContentResponse description: >- - Response from retrieving the contents of a vector store file. + Represents the parsed content of a vector store file. OpenaiSearchVectorStoreRequest: type: object properties: diff --git a/src/llama_stack/apis/vector_io/vector_io.py b/src/llama_stack/apis/vector_io/vector_io.py index 26c961db3..846c6f191 100644 --- a/src/llama_stack/apis/vector_io/vector_io.py +++ b/src/llama_stack/apis/vector_io/vector_io.py @@ -396,19 +396,19 @@ class VectorStoreListFilesResponse(BaseModel): @json_schema_type -class VectorStoreFileContentsResponse(BaseModel): - """Response from retrieving the contents of a vector store file. +class VectorStoreFileContentResponse(BaseModel): + """Represents the parsed content of a vector store file. - :param file_id: Unique identifier for the file - :param filename: Name of the file - :param attributes: Key-value attributes associated with the file - :param content: List of content items from the file + :param object: The object type, which is always `vector_store.file_content.page` + :param data: Parsed content of the file + :param has_more: Indicates if there are more content pages to fetch + :param next_page: The token for the next page, if any """ - file_id: str - filename: str - attributes: dict[str, Any] - content: list[VectorStoreContent] + object: Literal["vector_store.file_content.page"] = "vector_store.file_content.page" + data: list[VectorStoreContent] + has_more: bool + next_page: str | None = None @json_schema_type @@ -732,12 +732,12 @@ class VectorIO(Protocol): self, vector_store_id: str, file_id: str, - ) -> VectorStoreFileContentsResponse: + ) -> VectorStoreFileContentResponse: """Retrieves the contents of a vector store file. :param vector_store_id: The ID of the vector store containing the file to retrieve. :param file_id: The ID of the file to retrieve. - :returns: A list of InterleavedContent representing the file contents. 
+ :returns: A VectorStoreFileContentResponse representing the file contents. """ ... diff --git a/src/llama_stack/core/routers/vector_io.py b/src/llama_stack/core/routers/vector_io.py index b54217619..9dac461db 100644 --- a/src/llama_stack/core/routers/vector_io.py +++ b/src/llama_stack/core/routers/vector_io.py @@ -24,7 +24,7 @@ from llama_stack.apis.vector_io import ( VectorStoreChunkingStrategyStaticConfig, VectorStoreDeleteResponse, VectorStoreFileBatchObject, - VectorStoreFileContentsResponse, + VectorStoreFileContentResponse, VectorStoreFileDeleteResponse, VectorStoreFileObject, VectorStoreFilesListInBatchResponse, @@ -338,7 +338,7 @@ class VectorIORouter(VectorIO): self, vector_store_id: str, file_id: str, - ) -> VectorStoreFileContentsResponse: + ) -> VectorStoreFileContentResponse: logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}") provider = await self.routing_table.get_provider_impl(vector_store_id) return await provider.openai_retrieve_vector_store_file_contents( diff --git a/src/llama_stack/core/routing_tables/vector_stores.py b/src/llama_stack/core/routing_tables/vector_stores.py index c6c80a01e..f95a4dbe3 100644 --- a/src/llama_stack/core/routing_tables/vector_stores.py +++ b/src/llama_stack/core/routing_tables/vector_stores.py @@ -15,7 +15,7 @@ from llama_stack.apis.vector_io.vector_io import ( SearchRankingOptions, VectorStoreChunkingStrategy, VectorStoreDeleteResponse, - VectorStoreFileContentsResponse, + VectorStoreFileContentResponse, VectorStoreFileDeleteResponse, VectorStoreFileObject, VectorStoreFileStatus, @@ -195,7 +195,7 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): self, vector_store_id: str, file_id: str, - ) -> VectorStoreFileContentsResponse: + ) -> VectorStoreFileContentResponse: await self.assert_action_allowed("read", "vector_store", vector_store_id) provider = await self.get_provider_impl(vector_store_id) return await 
provider.openai_retrieve_vector_store_file_contents( diff --git a/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index d047d9d12..86e6ea013 100644 --- a/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import ( VectorStoreContent, VectorStoreDeleteResponse, VectorStoreFileBatchObject, - VectorStoreFileContentsResponse, + VectorStoreFileContentResponse, VectorStoreFileCounts, VectorStoreFileDeleteResponse, VectorStoreFileLastError, @@ -921,22 +921,21 @@ class OpenAIVectorStoreMixin(ABC): self, vector_store_id: str, file_id: str, - ) -> VectorStoreFileContentsResponse: + ) -> VectorStoreFileContentResponse: """Retrieves the contents of a vector store file.""" if vector_store_id not in self.openai_vector_stores: raise VectorStoreNotFoundError(vector_store_id) - file_info = await self._load_openai_vector_store_file(vector_store_id, file_id) dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id) chunks = [Chunk.model_validate(c) for c in dict_chunks] content = [] for chunk in chunks: content.extend(self._chunk_to_vector_store_content(chunk)) - return VectorStoreFileContentsResponse( - file_id=file_id, - filename=file_info.get("filename", ""), - attributes=file_info.get("attributes", {}), - content=content, + return VectorStoreFileContentResponse( + object="vector_store.file_content.page", + data=content, + has_more=False, + next_page=None, ) async def openai_update_vector_store_file( diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 97ce4abe8..20f9d2978 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -907,16 +907,16 @@ def 
test_openai_vector_store_retrieve_file_contents( ) assert file_contents is not None - assert len(file_contents.content) == 1 - content = file_contents.content[0] + assert file_contents.object == "vector_store.file_content.page" + assert len(file_contents.data) == 1 + content = file_contents.data[0] # llama-stack-client returns a model, openai-python is a badboy and returns a dict if not isinstance(content, dict): content = content.model_dump() assert content["type"] == "text" assert content["text"] == test_content.decode("utf-8") - assert file_contents.filename == file_name - assert file_contents.attributes == attributes + assert file_contents.has_more is False @vector_provider_wrapper @@ -1483,14 +1483,12 @@ def test_openai_vector_store_file_batch_retrieve_contents( ) assert file_contents is not None - assert file_contents.filename == file_data[i][0] - assert len(file_contents.content) > 0 + assert file_contents.object == "vector_store.file_content.page" + assert len(file_contents.data) > 0 # Verify the content matches what we uploaded content_text = ( - file_contents.content[0].text - if hasattr(file_contents.content[0], "text") - else file_contents.content[0]["text"] + file_contents.data[0].text if hasattr(file_contents.data[0], "text") else file_contents.data[0]["text"] ) assert file_data[i][1].decode("utf-8") in content_text From fadf17daf37c1518a5b05adf56bc0939453c0a6e Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 10 Nov 2025 10:36:33 -0800 Subject: [PATCH 02/15] feat(api)!: deprecate register/unregister resource APIs (#4099) Mark all register_* / unregister_* APIs as deprecated across models, shields, tool groups, datasets, benchmarks, and scoring functions. This is the first step toward moving resource mutations to an `/admin` namespace as outlined in https://github.com/llamastack/llama-stack/issues/3809#issuecomment-3492931585. 
The deprecation flag will be reflected in the OpenAPI schema to warn API users that these endpoints are being phased out. Next step will be implementing the `/admin` route namespace for these resource management operations. - `register_model` / `unregister_model` - `register_shield` / `unregister_shield` - `register_tool_group` / `unregister_toolgroup` - `register_dataset` / `unregister_dataset` - `register_benchmark` / `unregister_benchmark` - `register_scoring_function` / `unregister_scoring_function` --- client-sdks/stainless/openapi.yml | 603 ++------- docs/static/deprecated-llama-stack-spec.yaml | 1094 ++++++++++++++++- .../static/experimental-llama-stack-spec.yaml | 214 ++-- docs/static/llama-stack-spec.yaml | 389 +----- docs/static/stainless-llama-stack-spec.yaml | 603 ++------- src/llama_stack/apis/benchmarks/benchmarks.py | 4 +- src/llama_stack/apis/datasets/datasets.py | 4 +- src/llama_stack/apis/models/models.py | 4 +- .../scoring_functions/scoring_functions.py | 6 +- src/llama_stack/apis/shields/shields.py | 4 +- src/llama_stack/apis/tools/tools.py | 4 +- 11 files changed, 1454 insertions(+), 1475 deletions(-) diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml index adee2f086..2b9849535 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -998,39 +998,6 @@ paths: description: List models using the OpenAI API. parameters: [] deprecated: false - post: - responses: - '200': - description: A Model. - content: - application/json: - schema: - $ref: '#/components/schemas/Model' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Models - summary: Register model. - description: >- - Register model. - - Register a model. 
- parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterModelRequest' - required: true - deprecated: false /v1/models/{model_id}: get: responses: @@ -1065,36 +1032,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Models - summary: Unregister model. - description: >- - Unregister model. - - Unregister a model. - parameters: - - name: model_id - in: path - description: >- - The identifier of the model to unregister. - required: true - schema: - type: string - deprecated: false /v1/moderations: post: responses: @@ -1725,32 +1662,6 @@ paths: description: List all scoring functions. parameters: [] deprecated: false - post: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ScoringFunctions - summary: Register a scoring function. - description: Register a scoring function. 
- parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterScoringFunctionRequest' - required: true - deprecated: false /v1/scoring-functions/{scoring_fn_id}: get: responses: @@ -1782,33 +1693,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ScoringFunctions - summary: Unregister a scoring function. - description: Unregister a scoring function. - parameters: - - name: scoring_fn_id - in: path - description: >- - The ID of the scoring function to unregister. - required: true - schema: - type: string - deprecated: false /v1/scoring/score: post: responses: @@ -1897,36 +1781,6 @@ paths: description: List all shields. parameters: [] deprecated: false - post: - responses: - '200': - description: A Shield. - content: - application/json: - schema: - $ref: '#/components/schemas/Shield' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Shields - summary: Register a shield. - description: Register a shield. 
- parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterShieldRequest' - required: true - deprecated: false /v1/shields/{identifier}: get: responses: @@ -1958,33 +1812,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Shields - summary: Unregister a shield. - description: Unregister a shield. - parameters: - - name: identifier - in: path - description: >- - The identifier of the shield to unregister. - required: true - schema: - type: string - deprecated: false /v1/tool-runtime/invoke: post: responses: @@ -2080,32 +1907,6 @@ paths: description: List tool groups with optional provider. parameters: [] deprecated: false - post: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ToolGroups - summary: Register a tool group. - description: Register a tool group. 
- parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterToolGroupRequest' - required: true - deprecated: false /v1/toolgroups/{toolgroup_id}: get: responses: @@ -2137,32 +1938,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ToolGroups - summary: Unregister a tool group. - description: Unregister a tool group. - parameters: - - name: toolgroup_id - in: path - description: The ID of the tool group to unregister. - required: true - schema: - type: string - deprecated: false /v1/tools: get: responses: @@ -3171,7 +2946,7 @@ paths: schema: $ref: '#/components/schemas/RegisterDatasetRequest' required: true - deprecated: false + deprecated: true /v1beta/datasets/{dataset_id}: get: responses: @@ -3228,7 +3003,7 @@ paths: required: true schema: type: string - deprecated: false + deprecated: true /v1alpha/eval/benchmarks: get: responses: @@ -3279,7 +3054,7 @@ paths: schema: $ref: '#/components/schemas/RegisterBenchmarkRequest' required: true - deprecated: false + deprecated: true /v1alpha/eval/benchmarks/{benchmark_id}: get: responses: @@ -3336,7 +3111,7 @@ paths: required: true schema: type: string - deprecated: false + deprecated: true /v1alpha/eval/benchmarks/{benchmark_id}/evaluations: post: responses: @@ -6280,46 +6055,6 @@ components: required: - data title: OpenAIListModelsResponse - ModelType: - type: string - enum: - - llm - - embedding - - rerank - title: ModelType - description: >- - Enumeration of supported model types in Llama Stack. - RegisterModelRequest: - type: object - properties: - model_id: - type: string - description: The identifier of the model to register. 
- provider_model_id: - type: string - description: >- - The identifier of the model in the provider. - provider_id: - type: string - description: The identifier of the provider. - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: Any additional metadata for this model. - model_type: - $ref: '#/components/schemas/ModelType' - description: The type of model to register. - additionalProperties: false - required: - - model_id - title: RegisterModelRequest Model: type: object properties: @@ -6377,6 +6112,15 @@ components: title: Model description: >- A model resource representing an AI model registered in Llama Stack. + ModelType: + type: string + enum: + - llm + - embedding + - rerank + title: ModelType + description: >- + Enumeration of supported model types in Llama Stack. RunModerationRequest: type: object properties: @@ -9115,61 +8859,6 @@ components: required: - data title: ListScoringFunctionsResponse - ParamType: - oneOf: - - $ref: '#/components/schemas/StringType' - - $ref: '#/components/schemas/NumberType' - - $ref: '#/components/schemas/BooleanType' - - $ref: '#/components/schemas/ArrayType' - - $ref: '#/components/schemas/ObjectType' - - $ref: '#/components/schemas/JsonType' - - $ref: '#/components/schemas/UnionType' - - $ref: '#/components/schemas/ChatCompletionInputType' - - $ref: '#/components/schemas/CompletionInputType' - discriminator: - propertyName: type - mapping: - string: '#/components/schemas/StringType' - number: '#/components/schemas/NumberType' - boolean: '#/components/schemas/BooleanType' - array: '#/components/schemas/ArrayType' - object: '#/components/schemas/ObjectType' - json: '#/components/schemas/JsonType' - union: '#/components/schemas/UnionType' - chat_completion_input: '#/components/schemas/ChatCompletionInputType' - completion_input: '#/components/schemas/CompletionInputType' - RegisterScoringFunctionRequest: 
- type: object - properties: - scoring_fn_id: - type: string - description: >- - The ID of the scoring function to register. - description: - type: string - description: The description of the scoring function. - return_type: - $ref: '#/components/schemas/ParamType' - description: The return type of the scoring function. - provider_scoring_fn_id: - type: string - description: >- - The ID of the provider scoring function to use for the scoring function. - provider_id: - type: string - description: >- - The ID of the provider to use for the scoring function. - params: - $ref: '#/components/schemas/ScoringFnParams' - description: >- - The parameters for the scoring function for benchmark eval, these can - be overridden for app eval. - additionalProperties: false - required: - - scoring_fn_id - - description - - return_type - title: RegisterScoringFunctionRequest ScoreRequest: type: object properties: @@ -9345,35 +9034,6 @@ components: required: - data title: ListShieldsResponse - RegisterShieldRequest: - type: object - properties: - shield_id: - type: string - description: >- - The identifier of the shield to register. - provider_shield_id: - type: string - description: >- - The identifier of the shield in the provider. - provider_id: - type: string - description: The identifier of the provider. - params: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: The parameters of the shield. - additionalProperties: false - required: - - shield_id - title: RegisterShieldRequest InvokeToolRequest: type: object properties: @@ -9634,37 +9294,6 @@ components: title: ListToolGroupsResponse description: >- Response containing a list of tool groups. - RegisterToolGroupRequest: - type: object - properties: - toolgroup_id: - type: string - description: The ID of the tool group to register. 
- provider_id: - type: string - description: >- - The ID of the provider to use for the tool group. - mcp_endpoint: - $ref: '#/components/schemas/URL' - description: >- - The MCP endpoint to use for the tool group. - args: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - A dictionary of arguments to pass to the tool group. - additionalProperties: false - required: - - toolgroup_id - - provider_id - title: RegisterToolGroupRequest Chunk: type: object properties: @@ -10810,68 +10439,6 @@ components: - data title: ListDatasetsResponse description: Response from listing datasets. - DataSource: - oneOf: - - $ref: '#/components/schemas/URIDataSource' - - $ref: '#/components/schemas/RowsDataSource' - discriminator: - propertyName: type - mapping: - uri: '#/components/schemas/URIDataSource' - rows: '#/components/schemas/RowsDataSource' - RegisterDatasetRequest: - type: object - properties: - purpose: - type: string - enum: - - post-training/messages - - eval/question-answer - - eval/messages-answer - description: >- - The purpose of the dataset. One of: - "post-training/messages": The dataset - contains a messages column with list of messages for post-training. { - "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant", - "content": "Hello, world!"}, ] } - "eval/question-answer": The dataset - contains a question column and an answer column for evaluation. { "question": - "What is the capital of France?", "answer": "Paris" } - "eval/messages-answer": - The dataset contains a messages column with list of messages and an answer - column for evaluation. { "messages": [ {"role": "user", "content": "Hello, - my name is John Doe."}, {"role": "assistant", "content": "Hello, John - Doe. 
How can I help you today?"}, {"role": "user", "content": "What's - my name?"}, ], "answer": "John Doe" } - source: - $ref: '#/components/schemas/DataSource' - description: >- - The data source of the dataset. Ensure that the data source schema is - compatible with the purpose of the dataset. Examples: - { "type": "uri", - "uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": - "lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}" - } - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train" - } - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content": - "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ] - } ] } - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - The metadata for the dataset. - E.g. {"description": "My dataset"}. - dataset_id: - type: string - description: >- - The ID of the dataset. If not provided, an ID will be generated. - additionalProperties: false - required: - - purpose - - source - title: RegisterDatasetRequest Benchmark: type: object properties: @@ -10939,47 +10506,6 @@ components: required: - data title: ListBenchmarksResponse - RegisterBenchmarkRequest: - type: object - properties: - benchmark_id: - type: string - description: The ID of the benchmark to register. - dataset_id: - type: string - description: >- - The ID of the dataset to use for the benchmark. - scoring_functions: - type: array - items: - type: string - description: >- - The scoring functions to use for the benchmark. - provider_benchmark_id: - type: string - description: >- - The ID of the provider benchmark to use for the benchmark. - provider_id: - type: string - description: >- - The ID of the provider to use for the benchmark. 
- metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: The metadata to use for the benchmark. - additionalProperties: false - required: - - benchmark_id - - dataset_id - - scoring_functions - title: RegisterBenchmarkRequest BenchmarkConfig: type: object properties: @@ -11841,6 +11367,109 @@ components: - hyperparam_search_config - logger_config title: SupervisedFineTuneRequest + DataSource: + oneOf: + - $ref: '#/components/schemas/URIDataSource' + - $ref: '#/components/schemas/RowsDataSource' + discriminator: + propertyName: type + mapping: + uri: '#/components/schemas/URIDataSource' + rows: '#/components/schemas/RowsDataSource' + RegisterDatasetRequest: + type: object + properties: + purpose: + type: string + enum: + - post-training/messages + - eval/question-answer + - eval/messages-answer + description: >- + The purpose of the dataset. One of: - "post-training/messages": The dataset + contains a messages column with list of messages for post-training. { + "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant", + "content": "Hello, world!"}, ] } - "eval/question-answer": The dataset + contains a question column and an answer column for evaluation. { "question": + "What is the capital of France?", "answer": "Paris" } - "eval/messages-answer": + The dataset contains a messages column with list of messages and an answer + column for evaluation. { "messages": [ {"role": "user", "content": "Hello, + my name is John Doe."}, {"role": "assistant", "content": "Hello, John + Doe. How can I help you today?"}, {"role": "user", "content": "What's + my name?"}, ], "answer": "John Doe" } + source: + $ref: '#/components/schemas/DataSource' + description: >- + The data source of the dataset. Ensure that the data source schema is + compatible with the purpose of the dataset. 
Examples: - { "type": "uri", + "uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": + "lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}" + } - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train" + } - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content": + "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ] + } ] } + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + The metadata for the dataset. - E.g. {"description": "My dataset"}. + dataset_id: + type: string + description: >- + The ID of the dataset. If not provided, an ID will be generated. + additionalProperties: false + required: + - purpose + - source + title: RegisterDatasetRequest + RegisterBenchmarkRequest: + type: object + properties: + benchmark_id: + type: string + description: The ID of the benchmark to register. + dataset_id: + type: string + description: >- + The ID of the dataset to use for the benchmark. + scoring_functions: + type: array + items: + type: string + description: >- + The scoring functions to use for the benchmark. + provider_benchmark_id: + type: string + description: >- + The ID of the provider benchmark to use for the benchmark. + provider_id: + type: string + description: >- + The ID of the provider to use for the benchmark. + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: The metadata to use for the benchmark. 
+ additionalProperties: false + required: + - benchmark_id + - dataset_id + - scoring_functions + title: RegisterBenchmarkRequest responses: BadRequest400: description: The request was invalid or malformed diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index 3bc965eb7..dea2e5bbe 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -13,7 +13,352 @@ info: migration reference only. servers: - url: http://any-hosted-llama-stack.com -paths: {} +paths: + /v1/models: + post: + responses: + '200': + description: A Model. + content: + application/json: + schema: + $ref: '#/components/schemas/Model' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Models + summary: Register model. + description: >- + Register model. + + Register a model. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RegisterModelRequest' + required: true + deprecated: true + /v1/models/{model_id}: + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Models + summary: Unregister model. + description: >- + Unregister model. + + Unregister a model. + parameters: + - name: model_id + in: path + description: >- + The identifier of the model to unregister. 
+ required: true + schema: + type: string + deprecated: true + /v1/scoring-functions: + post: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - ScoringFunctions + summary: Register a scoring function. + description: Register a scoring function. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RegisterScoringFunctionRequest' + required: true + deprecated: true + /v1/scoring-functions/{scoring_fn_id}: + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - ScoringFunctions + summary: Unregister a scoring function. + description: Unregister a scoring function. + parameters: + - name: scoring_fn_id + in: path + description: >- + The ID of the scoring function to unregister. + required: true + schema: + type: string + deprecated: true + /v1/shields: + post: + responses: + '200': + description: A Shield. + content: + application/json: + schema: + $ref: '#/components/schemas/Shield' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Shields + summary: Register a shield. + description: Register a shield. 
+ parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RegisterShieldRequest' + required: true + deprecated: true + /v1/shields/{identifier}: + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Shields + summary: Unregister a shield. + description: Unregister a shield. + parameters: + - name: identifier + in: path + description: >- + The identifier of the shield to unregister. + required: true + schema: + type: string + deprecated: true + /v1/toolgroups: + post: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - ToolGroups + summary: Register a tool group. + description: Register a tool group. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RegisterToolGroupRequest' + required: true + deprecated: true + /v1/toolgroups/{toolgroup_id}: + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - ToolGroups + summary: Unregister a tool group. + description: Unregister a tool group. + parameters: + - name: toolgroup_id + in: path + description: The ID of the tool group to unregister. + required: true + schema: + type: string + deprecated: true + /v1beta/datasets: + post: + responses: + '200': + description: A Dataset. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/Dataset' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Datasets + summary: Register a new dataset. + description: Register a new dataset. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RegisterDatasetRequest' + required: true + deprecated: true + /v1beta/datasets/{dataset_id}: + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Datasets + summary: Unregister a dataset by its ID. + description: Unregister a dataset by its ID. + parameters: + - name: dataset_id + in: path + description: The ID of the dataset to unregister. + required: true + schema: + type: string + deprecated: true + /v1alpha/eval/benchmarks: + post: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Benchmarks + summary: Register a benchmark. + description: Register a benchmark. 
+ parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RegisterBenchmarkRequest' + required: true + deprecated: true + /v1alpha/eval/benchmarks/{benchmark_id}: + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Benchmarks + summary: Unregister a benchmark. + description: Unregister a benchmark. + parameters: + - name: benchmark_id + in: path + description: The ID of the benchmark to unregister. + required: true + schema: + type: string + deprecated: true jsonSchemaDialect: >- https://json-schema.org/draft/2020-12/schema components: @@ -46,6 +391,730 @@ components: title: Error description: >- Error response from the API. Roughly follows RFC 7807. + ModelType: + type: string + enum: + - llm + - embedding + - rerank + title: ModelType + description: >- + Enumeration of supported model types in Llama Stack. + RegisterModelRequest: + type: object + properties: + model_id: + type: string + description: The identifier of the model to register. + provider_model_id: + type: string + description: >- + The identifier of the model in the provider. + provider_id: + type: string + description: The identifier of the provider. + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: Any additional metadata for this model. + model_type: + $ref: '#/components/schemas/ModelType' + description: The type of model to register. 
+ additionalProperties: false + required: + - model_id + title: RegisterModelRequest + Model: + type: object + properties: + identifier: + type: string + description: >- + Unique identifier for this resource in llama stack + provider_resource_id: + type: string + description: >- + Unique identifier for this resource in the provider + provider_id: + type: string + description: >- + ID of the provider that owns this resource + type: + type: string + enum: + - model + - shield + - vector_store + - dataset + - scoring_function + - benchmark + - tool + - tool_group + - prompt + const: model + default: model + description: >- + The resource type, always 'model' for model resources + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: Any additional metadata for this model + model_type: + $ref: '#/components/schemas/ModelType' + default: llm + description: >- + The type of model (LLM or embedding model) + additionalProperties: false + required: + - identifier + - provider_id + - type + - metadata + - model_type + title: Model + description: >- + A model resource representing an AI model registered in Llama Stack. + AggregationFunctionType: + type: string + enum: + - average + - weighted_average + - median + - categorical_count + - accuracy + title: AggregationFunctionType + description: >- + Types of aggregation functions for scoring results. + ArrayType: + type: object + properties: + type: + type: string + const: array + default: array + description: Discriminator type. Always "array" + additionalProperties: false + required: + - type + title: ArrayType + description: Parameter type for array values. 
+ BasicScoringFnParams: + type: object + properties: + type: + $ref: '#/components/schemas/ScoringFnParamsType' + const: basic + default: basic + description: >- + The type of scoring function parameters, always basic + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + description: >- + Aggregation functions to apply to the scores of each row + additionalProperties: false + required: + - type + - aggregation_functions + title: BasicScoringFnParams + description: >- + Parameters for basic scoring function configuration. + BooleanType: + type: object + properties: + type: + type: string + const: boolean + default: boolean + description: Discriminator type. Always "boolean" + additionalProperties: false + required: + - type + title: BooleanType + description: Parameter type for boolean values. + ChatCompletionInputType: + type: object + properties: + type: + type: string + const: chat_completion_input + default: chat_completion_input + description: >- + Discriminator type. Always "chat_completion_input" + additionalProperties: false + required: + - type + title: ChatCompletionInputType + description: >- + Parameter type for chat completion input. + CompletionInputType: + type: object + properties: + type: + type: string + const: completion_input + default: completion_input + description: >- + Discriminator type. Always "completion_input" + additionalProperties: false + required: + - type + title: CompletionInputType + description: Parameter type for completion input. + JsonType: + type: object + properties: + type: + type: string + const: json + default: json + description: Discriminator type. Always "json" + additionalProperties: false + required: + - type + title: JsonType + description: Parameter type for JSON values. 
+ LLMAsJudgeScoringFnParams: + type: object + properties: + type: + $ref: '#/components/schemas/ScoringFnParamsType' + const: llm_as_judge + default: llm_as_judge + description: >- + The type of scoring function parameters, always llm_as_judge + judge_model: + type: string + description: >- + Identifier of the LLM model to use as a judge for scoring + prompt_template: + type: string + description: >- + (Optional) Custom prompt template for the judge model + judge_score_regexes: + type: array + items: + type: string + description: >- + Regexes to extract the answer from generated response + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + description: >- + Aggregation functions to apply to the scores of each row + additionalProperties: false + required: + - type + - judge_model + - judge_score_regexes + - aggregation_functions + title: LLMAsJudgeScoringFnParams + description: >- + Parameters for LLM-as-judge scoring function configuration. + NumberType: + type: object + properties: + type: + type: string + const: number + default: number + description: Discriminator type. Always "number" + additionalProperties: false + required: + - type + title: NumberType + description: Parameter type for numeric values. + ObjectType: + type: object + properties: + type: + type: string + const: object + default: object + description: Discriminator type. Always "object" + additionalProperties: false + required: + - type + title: ObjectType + description: Parameter type for object values. 
+ ParamType: + oneOf: + - $ref: '#/components/schemas/StringType' + - $ref: '#/components/schemas/NumberType' + - $ref: '#/components/schemas/BooleanType' + - $ref: '#/components/schemas/ArrayType' + - $ref: '#/components/schemas/ObjectType' + - $ref: '#/components/schemas/JsonType' + - $ref: '#/components/schemas/UnionType' + - $ref: '#/components/schemas/ChatCompletionInputType' + - $ref: '#/components/schemas/CompletionInputType' + discriminator: + propertyName: type + mapping: + string: '#/components/schemas/StringType' + number: '#/components/schemas/NumberType' + boolean: '#/components/schemas/BooleanType' + array: '#/components/schemas/ArrayType' + object: '#/components/schemas/ObjectType' + json: '#/components/schemas/JsonType' + union: '#/components/schemas/UnionType' + chat_completion_input: '#/components/schemas/ChatCompletionInputType' + completion_input: '#/components/schemas/CompletionInputType' + RegexParserScoringFnParams: + type: object + properties: + type: + $ref: '#/components/schemas/ScoringFnParamsType' + const: regex_parser + default: regex_parser + description: >- + The type of scoring function parameters, always regex_parser + parsing_regexes: + type: array + items: + type: string + description: >- + Regex to extract the answer from generated response + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + description: >- + Aggregation functions to apply to the scores of each row + additionalProperties: false + required: + - type + - parsing_regexes + - aggregation_functions + title: RegexParserScoringFnParams + description: >- + Parameters for regex parser scoring function configuration. 
+ ScoringFnParams: + oneOf: + - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' + - $ref: '#/components/schemas/RegexParserScoringFnParams' + - $ref: '#/components/schemas/BasicScoringFnParams' + discriminator: + propertyName: type + mapping: + llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams' + regex_parser: '#/components/schemas/RegexParserScoringFnParams' + basic: '#/components/schemas/BasicScoringFnParams' + ScoringFnParamsType: + type: string + enum: + - llm_as_judge + - regex_parser + - basic + title: ScoringFnParamsType + description: >- + Types of scoring function parameter configurations. + StringType: + type: object + properties: + type: + type: string + const: string + default: string + description: Discriminator type. Always "string" + additionalProperties: false + required: + - type + title: StringType + description: Parameter type for string values. + UnionType: + type: object + properties: + type: + type: string + const: union + default: union + description: Discriminator type. Always "union" + additionalProperties: false + required: + - type + title: UnionType + description: Parameter type for union values. + RegisterScoringFunctionRequest: + type: object + properties: + scoring_fn_id: + type: string + description: >- + The ID of the scoring function to register. + description: + type: string + description: The description of the scoring function. + return_type: + $ref: '#/components/schemas/ParamType' + description: The return type of the scoring function. + provider_scoring_fn_id: + type: string + description: >- + The ID of the provider scoring function to use for the scoring function. + provider_id: + type: string + description: >- + The ID of the provider to use for the scoring function. + params: + $ref: '#/components/schemas/ScoringFnParams' + description: >- + The parameters for the scoring function for benchmark eval, these can + be overridden for app eval. 
+ additionalProperties: false + required: + - scoring_fn_id + - description + - return_type + title: RegisterScoringFunctionRequest + RegisterShieldRequest: + type: object + properties: + shield_id: + type: string + description: >- + The identifier of the shield to register. + provider_shield_id: + type: string + description: >- + The identifier of the shield in the provider. + provider_id: + type: string + description: The identifier of the provider. + params: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: The parameters of the shield. + additionalProperties: false + required: + - shield_id + title: RegisterShieldRequest + Shield: + type: object + properties: + identifier: + type: string + provider_resource_id: + type: string + provider_id: + type: string + type: + type: string + enum: + - model + - shield + - vector_store + - dataset + - scoring_function + - benchmark + - tool + - tool_group + - prompt + const: shield + default: shield + description: The resource type, always shield + params: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + (Optional) Configuration parameters for the shield + additionalProperties: false + required: + - identifier + - provider_id + - type + title: Shield + description: >- + A safety shield resource that can be used to check content. + URL: + type: object + properties: + uri: + type: string + description: The URL string pointing to the resource + additionalProperties: false + required: + - uri + title: URL + description: A URL reference to external content. + RegisterToolGroupRequest: + type: object + properties: + toolgroup_id: + type: string + description: The ID of the tool group to register. + provider_id: + type: string + description: >- + The ID of the provider to use for the tool group. 
+ mcp_endpoint: + $ref: '#/components/schemas/URL' + description: >- + The MCP endpoint to use for the tool group. + args: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + A dictionary of arguments to pass to the tool group. + additionalProperties: false + required: + - toolgroup_id + - provider_id + title: RegisterToolGroupRequest + DataSource: + oneOf: + - $ref: '#/components/schemas/URIDataSource' + - $ref: '#/components/schemas/RowsDataSource' + discriminator: + propertyName: type + mapping: + uri: '#/components/schemas/URIDataSource' + rows: '#/components/schemas/RowsDataSource' + RowsDataSource: + type: object + properties: + type: + type: string + const: rows + default: rows + rows: + type: array + items: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + The dataset is stored in rows. E.g. - [ {"messages": [{"role": "user", + "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, + world!"}]} ] + additionalProperties: false + required: + - type + - rows + title: RowsDataSource + description: A dataset stored in rows. + URIDataSource: + type: object + properties: + type: + type: string + const: uri + default: uri + uri: + type: string + description: >- + The dataset can be obtained from a URI. E.g. - "https://mywebsite.com/mydata.jsonl" + - "lsfs://mydata.jsonl" - "data:csv;base64,{base64_content}" + additionalProperties: false + required: + - type + - uri + title: URIDataSource + description: >- + A dataset that can be obtained from a URI. + RegisterDatasetRequest: + type: object + properties: + purpose: + type: string + enum: + - post-training/messages + - eval/question-answer + - eval/messages-answer + description: >- + The purpose of the dataset. 
One of: - "post-training/messages": The dataset + contains a messages column with list of messages for post-training. { + "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant", + "content": "Hello, world!"}, ] } - "eval/question-answer": The dataset + contains a question column and an answer column for evaluation. { "question": + "What is the capital of France?", "answer": "Paris" } - "eval/messages-answer": + The dataset contains a messages column with list of messages and an answer + column for evaluation. { "messages": [ {"role": "user", "content": "Hello, + my name is John Doe."}, {"role": "assistant", "content": "Hello, John + Doe. How can I help you today?"}, {"role": "user", "content": "What's + my name?"}, ], "answer": "John Doe" } + source: + $ref: '#/components/schemas/DataSource' + description: >- + The data source of the dataset. Ensure that the data source schema is + compatible with the purpose of the dataset. Examples: - { "type": "uri", + "uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": + "lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}" + } - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train" + } - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content": + "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ] + } ] } + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + The metadata for the dataset. - E.g. {"description": "My dataset"}. + dataset_id: + type: string + description: >- + The ID of the dataset. If not provided, an ID will be generated. 
+ additionalProperties: false + required: + - purpose + - source + title: RegisterDatasetRequest + Dataset: + type: object + properties: + identifier: + type: string + provider_resource_id: + type: string + provider_id: + type: string + type: + type: string + enum: + - model + - shield + - vector_store + - dataset + - scoring_function + - benchmark + - tool + - tool_group + - prompt + const: dataset + default: dataset + description: >- + Type of resource, always 'dataset' for datasets + purpose: + type: string + enum: + - post-training/messages + - eval/question-answer + - eval/messages-answer + description: >- + Purpose of the dataset indicating its intended use + source: + oneOf: + - $ref: '#/components/schemas/URIDataSource' + - $ref: '#/components/schemas/RowsDataSource' + discriminator: + propertyName: type + mapping: + uri: '#/components/schemas/URIDataSource' + rows: '#/components/schemas/RowsDataSource' + description: >- + Data source configuration for the dataset + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: Additional metadata for the dataset + additionalProperties: false + required: + - identifier + - provider_id + - type + - purpose + - source + - metadata + title: Dataset + description: >- + Dataset resource for storing and accessing training or evaluation data. + RegisterBenchmarkRequest: + type: object + properties: + benchmark_id: + type: string + description: The ID of the benchmark to register. + dataset_id: + type: string + description: >- + The ID of the dataset to use for the benchmark. + scoring_functions: + type: array + items: + type: string + description: >- + The scoring functions to use for the benchmark. + provider_benchmark_id: + type: string + description: >- + The ID of the provider benchmark to use for the benchmark. 
+ provider_id: + type: string + description: >- + The ID of the provider to use for the benchmark. + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: The metadata to use for the benchmark. + additionalProperties: false + required: + - benchmark_id + - dataset_id + - scoring_functions + title: RegisterBenchmarkRequest responses: BadRequest400: description: The request was invalid or malformed @@ -93,4 +1162,25 @@ components: detail: An unexpected error occurred security: - Default: [] -tags: [] +tags: + - name: Benchmarks + description: '' + - name: Datasets + description: '' + - name: Models + description: '' + - name: ScoringFunctions + description: '' + - name: Shields + description: '' + - name: ToolGroups + description: '' +x-tagGroups: + - name: Operations + tags: + - Benchmarks + - Datasets + - Models + - ScoringFunctions + - Shields + - ToolGroups diff --git a/docs/static/experimental-llama-stack-spec.yaml b/docs/static/experimental-llama-stack-spec.yaml index 68e2f59be..6f379d17c 100644 --- a/docs/static/experimental-llama-stack-spec.yaml +++ b/docs/static/experimental-llama-stack-spec.yaml @@ -162,7 +162,7 @@ paths: schema: $ref: '#/components/schemas/RegisterDatasetRequest' required: true - deprecated: false + deprecated: true /v1beta/datasets/{dataset_id}: get: responses: @@ -219,7 +219,7 @@ paths: required: true schema: type: string - deprecated: false + deprecated: true /v1alpha/eval/benchmarks: get: responses: @@ -270,7 +270,7 @@ paths: schema: $ref: '#/components/schemas/RegisterBenchmarkRequest' required: true - deprecated: false + deprecated: true /v1alpha/eval/benchmarks/{benchmark_id}: get: responses: @@ -327,7 +327,7 @@ paths: required: true schema: type: string - deprecated: false + deprecated: true /v1alpha/eval/benchmarks/{benchmark_id}/evaluations: post: responses: @@ -936,68 +936,6 @@ components: - data title: 
ListDatasetsResponse description: Response from listing datasets. - DataSource: - oneOf: - - $ref: '#/components/schemas/URIDataSource' - - $ref: '#/components/schemas/RowsDataSource' - discriminator: - propertyName: type - mapping: - uri: '#/components/schemas/URIDataSource' - rows: '#/components/schemas/RowsDataSource' - RegisterDatasetRequest: - type: object - properties: - purpose: - type: string - enum: - - post-training/messages - - eval/question-answer - - eval/messages-answer - description: >- - The purpose of the dataset. One of: - "post-training/messages": The dataset - contains a messages column with list of messages for post-training. { - "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant", - "content": "Hello, world!"}, ] } - "eval/question-answer": The dataset - contains a question column and an answer column for evaluation. { "question": - "What is the capital of France?", "answer": "Paris" } - "eval/messages-answer": - The dataset contains a messages column with list of messages and an answer - column for evaluation. { "messages": [ {"role": "user", "content": "Hello, - my name is John Doe."}, {"role": "assistant", "content": "Hello, John - Doe. How can I help you today?"}, {"role": "user", "content": "What's - my name?"}, ], "answer": "John Doe" } - source: - $ref: '#/components/schemas/DataSource' - description: >- - The data source of the dataset. Ensure that the data source schema is - compatible with the purpose of the dataset. 
Examples: - { "type": "uri", - "uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": - "lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}" - } - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train" - } - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content": - "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ] - } ] } - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - The metadata for the dataset. - E.g. {"description": "My dataset"}. - dataset_id: - type: string - description: >- - The ID of the dataset. If not provided, an ID will be generated. - additionalProperties: false - required: - - purpose - - source - title: RegisterDatasetRequest Benchmark: type: object properties: @@ -1065,47 +1003,6 @@ components: required: - data title: ListBenchmarksResponse - RegisterBenchmarkRequest: - type: object - properties: - benchmark_id: - type: string - description: The ID of the benchmark to register. - dataset_id: - type: string - description: >- - The ID of the dataset to use for the benchmark. - scoring_functions: - type: array - items: - type: string - description: >- - The scoring functions to use for the benchmark. - provider_benchmark_id: - type: string - description: >- - The ID of the provider benchmark to use for the benchmark. - provider_id: - type: string - description: >- - The ID of the provider to use for the benchmark. - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: The metadata to use for the benchmark. 
- additionalProperties: false - required: - - benchmark_id - - dataset_id - - scoring_functions - title: RegisterBenchmarkRequest AggregationFunctionType: type: string enum: @@ -2254,6 +2151,109 @@ components: - hyperparam_search_config - logger_config title: SupervisedFineTuneRequest + DataSource: + oneOf: + - $ref: '#/components/schemas/URIDataSource' + - $ref: '#/components/schemas/RowsDataSource' + discriminator: + propertyName: type + mapping: + uri: '#/components/schemas/URIDataSource' + rows: '#/components/schemas/RowsDataSource' + RegisterDatasetRequest: + type: object + properties: + purpose: + type: string + enum: + - post-training/messages + - eval/question-answer + - eval/messages-answer + description: >- + The purpose of the dataset. One of: - "post-training/messages": The dataset + contains a messages column with list of messages for post-training. { + "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant", + "content": "Hello, world!"}, ] } - "eval/question-answer": The dataset + contains a question column and an answer column for evaluation. { "question": + "What is the capital of France?", "answer": "Paris" } - "eval/messages-answer": + The dataset contains a messages column with list of messages and an answer + column for evaluation. { "messages": [ {"role": "user", "content": "Hello, + my name is John Doe."}, {"role": "assistant", "content": "Hello, John + Doe. How can I help you today?"}, {"role": "user", "content": "What's + my name?"}, ], "answer": "John Doe" } + source: + $ref: '#/components/schemas/DataSource' + description: >- + The data source of the dataset. Ensure that the data source schema is + compatible with the purpose of the dataset. 
Examples: - { "type": "uri", + "uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": + "lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}" + } - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train" + } - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content": + "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ] + } ] } + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + The metadata for the dataset. - E.g. {"description": "My dataset"}. + dataset_id: + type: string + description: >- + The ID of the dataset. If not provided, an ID will be generated. + additionalProperties: false + required: + - purpose + - source + title: RegisterDatasetRequest + RegisterBenchmarkRequest: + type: object + properties: + benchmark_id: + type: string + description: The ID of the benchmark to register. + dataset_id: + type: string + description: >- + The ID of the dataset to use for the benchmark. + scoring_functions: + type: array + items: + type: string + description: >- + The scoring functions to use for the benchmark. + provider_benchmark_id: + type: string + description: >- + The ID of the provider benchmark to use for the benchmark. + provider_id: + type: string + description: >- + The ID of the provider to use for the benchmark. + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: The metadata to use for the benchmark. 
+ additionalProperties: false + required: + - benchmark_id + - dataset_id + - scoring_functions + title: RegisterBenchmarkRequest responses: BadRequest400: description: The request was invalid or malformed diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index 72600bf13..4680afac9 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -995,39 +995,6 @@ paths: description: List models using the OpenAI API. parameters: [] deprecated: false - post: - responses: - '200': - description: A Model. - content: - application/json: - schema: - $ref: '#/components/schemas/Model' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Models - summary: Register model. - description: >- - Register model. - - Register a model. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterModelRequest' - required: true - deprecated: false /v1/models/{model_id}: get: responses: @@ -1062,36 +1029,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Models - summary: Unregister model. - description: >- - Unregister model. - - Unregister a model. - parameters: - - name: model_id - in: path - description: >- - The identifier of the model to unregister. - required: true - schema: - type: string - deprecated: false /v1/moderations: post: responses: @@ -1722,32 +1659,6 @@ paths: description: List all scoring functions. 
parameters: [] deprecated: false - post: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ScoringFunctions - summary: Register a scoring function. - description: Register a scoring function. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterScoringFunctionRequest' - required: true - deprecated: false /v1/scoring-functions/{scoring_fn_id}: get: responses: @@ -1779,33 +1690,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ScoringFunctions - summary: Unregister a scoring function. - description: Unregister a scoring function. - parameters: - - name: scoring_fn_id - in: path - description: >- - The ID of the scoring function to unregister. - required: true - schema: - type: string - deprecated: false /v1/scoring/score: post: responses: @@ -1894,36 +1778,6 @@ paths: description: List all shields. parameters: [] deprecated: false - post: - responses: - '200': - description: A Shield. - content: - application/json: - schema: - $ref: '#/components/schemas/Shield' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Shields - summary: Register a shield. - description: Register a shield. 
- parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterShieldRequest' - required: true - deprecated: false /v1/shields/{identifier}: get: responses: @@ -1955,33 +1809,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Shields - summary: Unregister a shield. - description: Unregister a shield. - parameters: - - name: identifier - in: path - description: >- - The identifier of the shield to unregister. - required: true - schema: - type: string - deprecated: false /v1/tool-runtime/invoke: post: responses: @@ -2077,32 +1904,6 @@ paths: description: List tool groups with optional provider. parameters: [] deprecated: false - post: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ToolGroups - summary: Register a tool group. - description: Register a tool group. 
- parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterToolGroupRequest' - required: true - deprecated: false /v1/toolgroups/{toolgroup_id}: get: responses: @@ -2134,32 +1935,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ToolGroups - summary: Unregister a tool group. - description: Unregister a tool group. - parameters: - - name: toolgroup_id - in: path - description: The ID of the tool group to unregister. - required: true - schema: - type: string - deprecated: false /v1/tools: get: responses: @@ -5564,46 +5339,6 @@ components: required: - data title: OpenAIListModelsResponse - ModelType: - type: string - enum: - - llm - - embedding - - rerank - title: ModelType - description: >- - Enumeration of supported model types in Llama Stack. - RegisterModelRequest: - type: object - properties: - model_id: - type: string - description: The identifier of the model to register. - provider_model_id: - type: string - description: >- - The identifier of the model in the provider. - provider_id: - type: string - description: The identifier of the provider. - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: Any additional metadata for this model. - model_type: - $ref: '#/components/schemas/ModelType' - description: The type of model to register. - additionalProperties: false - required: - - model_id - title: RegisterModelRequest Model: type: object properties: @@ -5661,6 +5396,15 @@ components: title: Model description: >- A model resource representing an AI model registered in Llama Stack. 
+ ModelType: + type: string + enum: + - llm + - embedding + - rerank + title: ModelType + description: >- + Enumeration of supported model types in Llama Stack. RunModerationRequest: type: object properties: @@ -8399,61 +8143,6 @@ components: required: - data title: ListScoringFunctionsResponse - ParamType: - oneOf: - - $ref: '#/components/schemas/StringType' - - $ref: '#/components/schemas/NumberType' - - $ref: '#/components/schemas/BooleanType' - - $ref: '#/components/schemas/ArrayType' - - $ref: '#/components/schemas/ObjectType' - - $ref: '#/components/schemas/JsonType' - - $ref: '#/components/schemas/UnionType' - - $ref: '#/components/schemas/ChatCompletionInputType' - - $ref: '#/components/schemas/CompletionInputType' - discriminator: - propertyName: type - mapping: - string: '#/components/schemas/StringType' - number: '#/components/schemas/NumberType' - boolean: '#/components/schemas/BooleanType' - array: '#/components/schemas/ArrayType' - object: '#/components/schemas/ObjectType' - json: '#/components/schemas/JsonType' - union: '#/components/schemas/UnionType' - chat_completion_input: '#/components/schemas/ChatCompletionInputType' - completion_input: '#/components/schemas/CompletionInputType' - RegisterScoringFunctionRequest: - type: object - properties: - scoring_fn_id: - type: string - description: >- - The ID of the scoring function to register. - description: - type: string - description: The description of the scoring function. - return_type: - $ref: '#/components/schemas/ParamType' - description: The return type of the scoring function. - provider_scoring_fn_id: - type: string - description: >- - The ID of the provider scoring function to use for the scoring function. - provider_id: - type: string - description: >- - The ID of the provider to use for the scoring function. - params: - $ref: '#/components/schemas/ScoringFnParams' - description: >- - The parameters for the scoring function for benchmark eval, these can - be overridden for app eval. 
- additionalProperties: false - required: - - scoring_fn_id - - description - - return_type - title: RegisterScoringFunctionRequest ScoreRequest: type: object properties: @@ -8629,35 +8318,6 @@ components: required: - data title: ListShieldsResponse - RegisterShieldRequest: - type: object - properties: - shield_id: - type: string - description: >- - The identifier of the shield to register. - provider_shield_id: - type: string - description: >- - The identifier of the shield in the provider. - provider_id: - type: string - description: The identifier of the provider. - params: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: The parameters of the shield. - additionalProperties: false - required: - - shield_id - title: RegisterShieldRequest InvokeToolRequest: type: object properties: @@ -8918,37 +8578,6 @@ components: title: ListToolGroupsResponse description: >- Response containing a list of tool groups. - RegisterToolGroupRequest: - type: object - properties: - toolgroup_id: - type: string - description: The ID of the tool group to register. - provider_id: - type: string - description: >- - The ID of the provider to use for the tool group. - mcp_endpoint: - $ref: '#/components/schemas/URL' - description: >- - The MCP endpoint to use for the tool group. - args: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - A dictionary of arguments to pass to the tool group. 
- additionalProperties: false - required: - - toolgroup_id - - provider_id - title: RegisterToolGroupRequest Chunk: type: object properties: diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index adee2f086..2b9849535 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -998,39 +998,6 @@ paths: description: List models using the OpenAI API. parameters: [] deprecated: false - post: - responses: - '200': - description: A Model. - content: - application/json: - schema: - $ref: '#/components/schemas/Model' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Models - summary: Register model. - description: >- - Register model. - - Register a model. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterModelRequest' - required: true - deprecated: false /v1/models/{model_id}: get: responses: @@ -1065,36 +1032,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Models - summary: Unregister model. - description: >- - Unregister model. - - Unregister a model. - parameters: - - name: model_id - in: path - description: >- - The identifier of the model to unregister. - required: true - schema: - type: string - deprecated: false /v1/moderations: post: responses: @@ -1725,32 +1662,6 @@ paths: description: List all scoring functions. 
parameters: [] deprecated: false - post: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ScoringFunctions - summary: Register a scoring function. - description: Register a scoring function. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterScoringFunctionRequest' - required: true - deprecated: false /v1/scoring-functions/{scoring_fn_id}: get: responses: @@ -1782,33 +1693,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ScoringFunctions - summary: Unregister a scoring function. - description: Unregister a scoring function. - parameters: - - name: scoring_fn_id - in: path - description: >- - The ID of the scoring function to unregister. - required: true - schema: - type: string - deprecated: false /v1/scoring/score: post: responses: @@ -1897,36 +1781,6 @@ paths: description: List all shields. parameters: [] deprecated: false - post: - responses: - '200': - description: A Shield. - content: - application/json: - schema: - $ref: '#/components/schemas/Shield' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Shields - summary: Register a shield. - description: Register a shield. 
- parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterShieldRequest' - required: true - deprecated: false /v1/shields/{identifier}: get: responses: @@ -1958,33 +1812,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Shields - summary: Unregister a shield. - description: Unregister a shield. - parameters: - - name: identifier - in: path - description: >- - The identifier of the shield to unregister. - required: true - schema: - type: string - deprecated: false /v1/tool-runtime/invoke: post: responses: @@ -2080,32 +1907,6 @@ paths: description: List tool groups with optional provider. parameters: [] deprecated: false - post: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ToolGroups - summary: Register a tool group. - description: Register a tool group. 
- parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterToolGroupRequest' - required: true - deprecated: false /v1/toolgroups/{toolgroup_id}: get: responses: @@ -2137,32 +1938,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ToolGroups - summary: Unregister a tool group. - description: Unregister a tool group. - parameters: - - name: toolgroup_id - in: path - description: The ID of the tool group to unregister. - required: true - schema: - type: string - deprecated: false /v1/tools: get: responses: @@ -3171,7 +2946,7 @@ paths: schema: $ref: '#/components/schemas/RegisterDatasetRequest' required: true - deprecated: false + deprecated: true /v1beta/datasets/{dataset_id}: get: responses: @@ -3228,7 +3003,7 @@ paths: required: true schema: type: string - deprecated: false + deprecated: true /v1alpha/eval/benchmarks: get: responses: @@ -3279,7 +3054,7 @@ paths: schema: $ref: '#/components/schemas/RegisterBenchmarkRequest' required: true - deprecated: false + deprecated: true /v1alpha/eval/benchmarks/{benchmark_id}: get: responses: @@ -3336,7 +3111,7 @@ paths: required: true schema: type: string - deprecated: false + deprecated: true /v1alpha/eval/benchmarks/{benchmark_id}/evaluations: post: responses: @@ -6280,46 +6055,6 @@ components: required: - data title: OpenAIListModelsResponse - ModelType: - type: string - enum: - - llm - - embedding - - rerank - title: ModelType - description: >- - Enumeration of supported model types in Llama Stack. - RegisterModelRequest: - type: object - properties: - model_id: - type: string - description: The identifier of the model to register. 
- provider_model_id: - type: string - description: >- - The identifier of the model in the provider. - provider_id: - type: string - description: The identifier of the provider. - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: Any additional metadata for this model. - model_type: - $ref: '#/components/schemas/ModelType' - description: The type of model to register. - additionalProperties: false - required: - - model_id - title: RegisterModelRequest Model: type: object properties: @@ -6377,6 +6112,15 @@ components: title: Model description: >- A model resource representing an AI model registered in Llama Stack. + ModelType: + type: string + enum: + - llm + - embedding + - rerank + title: ModelType + description: >- + Enumeration of supported model types in Llama Stack. RunModerationRequest: type: object properties: @@ -9115,61 +8859,6 @@ components: required: - data title: ListScoringFunctionsResponse - ParamType: - oneOf: - - $ref: '#/components/schemas/StringType' - - $ref: '#/components/schemas/NumberType' - - $ref: '#/components/schemas/BooleanType' - - $ref: '#/components/schemas/ArrayType' - - $ref: '#/components/schemas/ObjectType' - - $ref: '#/components/schemas/JsonType' - - $ref: '#/components/schemas/UnionType' - - $ref: '#/components/schemas/ChatCompletionInputType' - - $ref: '#/components/schemas/CompletionInputType' - discriminator: - propertyName: type - mapping: - string: '#/components/schemas/StringType' - number: '#/components/schemas/NumberType' - boolean: '#/components/schemas/BooleanType' - array: '#/components/schemas/ArrayType' - object: '#/components/schemas/ObjectType' - json: '#/components/schemas/JsonType' - union: '#/components/schemas/UnionType' - chat_completion_input: '#/components/schemas/ChatCompletionInputType' - completion_input: '#/components/schemas/CompletionInputType' - RegisterScoringFunctionRequest: 
- type: object - properties: - scoring_fn_id: - type: string - description: >- - The ID of the scoring function to register. - description: - type: string - description: The description of the scoring function. - return_type: - $ref: '#/components/schemas/ParamType' - description: The return type of the scoring function. - provider_scoring_fn_id: - type: string - description: >- - The ID of the provider scoring function to use for the scoring function. - provider_id: - type: string - description: >- - The ID of the provider to use for the scoring function. - params: - $ref: '#/components/schemas/ScoringFnParams' - description: >- - The parameters for the scoring function for benchmark eval, these can - be overridden for app eval. - additionalProperties: false - required: - - scoring_fn_id - - description - - return_type - title: RegisterScoringFunctionRequest ScoreRequest: type: object properties: @@ -9345,35 +9034,6 @@ components: required: - data title: ListShieldsResponse - RegisterShieldRequest: - type: object - properties: - shield_id: - type: string - description: >- - The identifier of the shield to register. - provider_shield_id: - type: string - description: >- - The identifier of the shield in the provider. - provider_id: - type: string - description: The identifier of the provider. - params: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: The parameters of the shield. - additionalProperties: false - required: - - shield_id - title: RegisterShieldRequest InvokeToolRequest: type: object properties: @@ -9634,37 +9294,6 @@ components: title: ListToolGroupsResponse description: >- Response containing a list of tool groups. - RegisterToolGroupRequest: - type: object - properties: - toolgroup_id: - type: string - description: The ID of the tool group to register. 
- provider_id: - type: string - description: >- - The ID of the provider to use for the tool group. - mcp_endpoint: - $ref: '#/components/schemas/URL' - description: >- - The MCP endpoint to use for the tool group. - args: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - A dictionary of arguments to pass to the tool group. - additionalProperties: false - required: - - toolgroup_id - - provider_id - title: RegisterToolGroupRequest Chunk: type: object properties: @@ -10810,68 +10439,6 @@ components: - data title: ListDatasetsResponse description: Response from listing datasets. - DataSource: - oneOf: - - $ref: '#/components/schemas/URIDataSource' - - $ref: '#/components/schemas/RowsDataSource' - discriminator: - propertyName: type - mapping: - uri: '#/components/schemas/URIDataSource' - rows: '#/components/schemas/RowsDataSource' - RegisterDatasetRequest: - type: object - properties: - purpose: - type: string - enum: - - post-training/messages - - eval/question-answer - - eval/messages-answer - description: >- - The purpose of the dataset. One of: - "post-training/messages": The dataset - contains a messages column with list of messages for post-training. { - "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant", - "content": "Hello, world!"}, ] } - "eval/question-answer": The dataset - contains a question column and an answer column for evaluation. { "question": - "What is the capital of France?", "answer": "Paris" } - "eval/messages-answer": - The dataset contains a messages column with list of messages and an answer - column for evaluation. { "messages": [ {"role": "user", "content": "Hello, - my name is John Doe."}, {"role": "assistant", "content": "Hello, John - Doe. 
How can I help you today?"}, {"role": "user", "content": "What's - my name?"}, ], "answer": "John Doe" } - source: - $ref: '#/components/schemas/DataSource' - description: >- - The data source of the dataset. Ensure that the data source schema is - compatible with the purpose of the dataset. Examples: - { "type": "uri", - "uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": - "lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}" - } - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train" - } - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content": - "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ] - } ] } - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - The metadata for the dataset. - E.g. {"description": "My dataset"}. - dataset_id: - type: string - description: >- - The ID of the dataset. If not provided, an ID will be generated. - additionalProperties: false - required: - - purpose - - source - title: RegisterDatasetRequest Benchmark: type: object properties: @@ -10939,47 +10506,6 @@ components: required: - data title: ListBenchmarksResponse - RegisterBenchmarkRequest: - type: object - properties: - benchmark_id: - type: string - description: The ID of the benchmark to register. - dataset_id: - type: string - description: >- - The ID of the dataset to use for the benchmark. - scoring_functions: - type: array - items: - type: string - description: >- - The scoring functions to use for the benchmark. - provider_benchmark_id: - type: string - description: >- - The ID of the provider benchmark to use for the benchmark. - provider_id: - type: string - description: >- - The ID of the provider to use for the benchmark. 
- metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: The metadata to use for the benchmark. - additionalProperties: false - required: - - benchmark_id - - dataset_id - - scoring_functions - title: RegisterBenchmarkRequest BenchmarkConfig: type: object properties: @@ -11841,6 +11367,109 @@ components: - hyperparam_search_config - logger_config title: SupervisedFineTuneRequest + DataSource: + oneOf: + - $ref: '#/components/schemas/URIDataSource' + - $ref: '#/components/schemas/RowsDataSource' + discriminator: + propertyName: type + mapping: + uri: '#/components/schemas/URIDataSource' + rows: '#/components/schemas/RowsDataSource' + RegisterDatasetRequest: + type: object + properties: + purpose: + type: string + enum: + - post-training/messages + - eval/question-answer + - eval/messages-answer + description: >- + The purpose of the dataset. One of: - "post-training/messages": The dataset + contains a messages column with list of messages for post-training. { + "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant", + "content": "Hello, world!"}, ] } - "eval/question-answer": The dataset + contains a question column and an answer column for evaluation. { "question": + "What is the capital of France?", "answer": "Paris" } - "eval/messages-answer": + The dataset contains a messages column with list of messages and an answer + column for evaluation. { "messages": [ {"role": "user", "content": "Hello, + my name is John Doe."}, {"role": "assistant", "content": "Hello, John + Doe. How can I help you today?"}, {"role": "user", "content": "What's + my name?"}, ], "answer": "John Doe" } + source: + $ref: '#/components/schemas/DataSource' + description: >- + The data source of the dataset. Ensure that the data source schema is + compatible with the purpose of the dataset. 
Examples: - { "type": "uri", + "uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": + "lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}" + } - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train" + } - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content": + "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ] + } ] } + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + The metadata for the dataset. - E.g. {"description": "My dataset"}. + dataset_id: + type: string + description: >- + The ID of the dataset. If not provided, an ID will be generated. + additionalProperties: false + required: + - purpose + - source + title: RegisterDatasetRequest + RegisterBenchmarkRequest: + type: object + properties: + benchmark_id: + type: string + description: The ID of the benchmark to register. + dataset_id: + type: string + description: >- + The ID of the dataset to use for the benchmark. + scoring_functions: + type: array + items: + type: string + description: >- + The scoring functions to use for the benchmark. + provider_benchmark_id: + type: string + description: >- + The ID of the provider benchmark to use for the benchmark. + provider_id: + type: string + description: >- + The ID of the provider to use for the benchmark. + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: The metadata to use for the benchmark. 
+ additionalProperties: false + required: + - benchmark_id + - dataset_id + - scoring_functions + title: RegisterBenchmarkRequest responses: BadRequest400: description: The request was invalid or malformed diff --git a/src/llama_stack/apis/benchmarks/benchmarks.py b/src/llama_stack/apis/benchmarks/benchmarks.py index 933205489..9a67269c3 100644 --- a/src/llama_stack/apis/benchmarks/benchmarks.py +++ b/src/llama_stack/apis/benchmarks/benchmarks.py @@ -74,7 +74,7 @@ class Benchmarks(Protocol): """ ... - @webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA) + @webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA, deprecated=True) async def register_benchmark( self, benchmark_id: str, @@ -95,7 +95,7 @@ class Benchmarks(Protocol): """ ... - @webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA) + @webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA, deprecated=True) async def unregister_benchmark(self, benchmark_id: str) -> None: """Unregister a benchmark. diff --git a/src/llama_stack/apis/datasets/datasets.py b/src/llama_stack/apis/datasets/datasets.py index ed4ecec22..9bedc6209 100644 --- a/src/llama_stack/apis/datasets/datasets.py +++ b/src/llama_stack/apis/datasets/datasets.py @@ -146,7 +146,7 @@ class ListDatasetsResponse(BaseModel): class Datasets(Protocol): - @webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA) + @webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA, deprecated=True) async def register_dataset( self, purpose: DatasetPurpose, @@ -235,7 +235,7 @@ class Datasets(Protocol): """ ... 
- @webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA) + @webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA, deprecated=True) async def unregister_dataset( self, dataset_id: str, diff --git a/src/llama_stack/apis/models/models.py b/src/llama_stack/apis/models/models.py index 5c976886c..bbb359b51 100644 --- a/src/llama_stack/apis/models/models.py +++ b/src/llama_stack/apis/models/models.py @@ -136,7 +136,7 @@ class Models(Protocol): """ ... - @webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1) + @webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1, deprecated=True) async def register_model( self, model_id: str, @@ -158,7 +158,7 @@ class Models(Protocol): """ ... - @webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1) + @webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True) async def unregister_model( self, model_id: str, diff --git a/src/llama_stack/apis/scoring_functions/scoring_functions.py b/src/llama_stack/apis/scoring_functions/scoring_functions.py index fe49723ab..78f4a7541 100644 --- a/src/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/src/llama_stack/apis/scoring_functions/scoring_functions.py @@ -178,7 +178,7 @@ class ScoringFunctions(Protocol): """ ... - @webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1) + @webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True) async def register_scoring_function( self, scoring_fn_id: str, @@ -199,7 +199,9 @@ class ScoringFunctions(Protocol): """ ... 
- @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1) + @webmethod( + route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True + ) async def unregister_scoring_function(self, scoring_fn_id: str) -> None: """Unregister a scoring function. diff --git a/src/llama_stack/apis/shields/shields.py b/src/llama_stack/apis/shields/shields.py index ca4483828..659ba8b75 100644 --- a/src/llama_stack/apis/shields/shields.py +++ b/src/llama_stack/apis/shields/shields.py @@ -67,7 +67,7 @@ class Shields(Protocol): """ ... - @webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1) + @webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1, deprecated=True) async def register_shield( self, shield_id: str, @@ -85,7 +85,7 @@ class Shields(Protocol): """ ... - @webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1) + @webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True) async def unregister_shield(self, identifier: str) -> None: """Unregister a shield. diff --git a/src/llama_stack/apis/tools/tools.py b/src/llama_stack/apis/tools/tools.py index c9bdfcfb6..4e7cf2544 100644 --- a/src/llama_stack/apis/tools/tools.py +++ b/src/llama_stack/apis/tools/tools.py @@ -109,7 +109,7 @@ class ListToolDefsResponse(BaseModel): @runtime_checkable @telemetry_traceable class ToolGroups(Protocol): - @webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1) + @webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1, deprecated=True) async def register_tool_group( self, toolgroup_id: str, @@ -167,7 +167,7 @@ class ToolGroups(Protocol): """ ... 
- @webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1) + @webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True) async def unregister_toolgroup( self, toolgroup_id: str, From 209a78b618f5e71b1ff384ba9877c815950ac8e1 Mon Sep 17 00:00:00 2001 From: Dennis Kennetz Date: Mon, 10 Nov 2025 15:16:24 -0600 Subject: [PATCH 03/15] feat: add oci genai service as chat inference provider (#3876) # What does this PR do? Adds OCI GenAI PaaS models for openai chat completion endpoints. ## Test Plan In an OCI tenancy with access to GenAI PaaS, perform the following steps: 1. Ensure you have IAM policies in place to use service (check docs included in this PR) 2. For local development, [setup OCI cli](https://docs.oracle.com/en-us/iaas/Content/API/SDKDocs/cliinstall.htm) and configure the CLI with your region, tenancy, and auth [here](https://docs.oracle.com/en-us/iaas/Content/API/SDKDocs/cliconfigure.htm) 3. Once configured, go through llama-stack setup and run llama-stack (uses config based auth) like: ```bash OCI_AUTH_TYPE=config_file \ OCI_CLI_PROFILE=CHICAGO \ OCI_REGION=us-chicago-1 \ OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..aaaaaaaa5...5a \ llama stack run oci ``` 4. Hit the `models` endpoint to list models after server is running: ```bash curl http://localhost:8321/v1/models | jq ... { "identifier": "meta.llama-4-scout-17b-16e-instruct", "provider_resource_id": "ocid1.generativeaimodel.oc1.us-chicago-1.am...q", "provider_id": "oci", "type": "model", "metadata": { "display_name": "meta.llama-4-scout-17b-16e-instruct", "capabilities": [ "CHAT" ], "oci_model_id": "ocid1.generativeaimodel.oc1.us-chicago-1.a...q" }, "model_type": "llm" }, ... ``` 5. 
Use the "display_name" field to use the model in a `/chat/completions` request: ```bash # Streaming result curl -X POST http://localhost:8321/v1/chat/completions -H "Content-Type: application/json" -d '{ "model": "meta.llama-4-scout-17b-16e-instruct", "stream": true, "temperature": 0.9, "messages": [ { "role": "system", "content": "You are a funny comedian. You can be crass." }, { "role": "user", "content": "Tell me a funny joke about programming." } ] }' # Non-streaming result curl -X POST http://localhost:8321/v1/chat/completions -H "Content-Type: application/json" -d '{ "model": "meta.llama-4-scout-17b-16e-instruct", "stream": false, "temperature": 0.9, "messages": [ { "role": "system", "content": "You are a funny comedian. You can be crass." }, { "role": "user", "content": "Tell me a funny joke about programming." } ] }' ``` 6. Try out other models from the `/models` endpoint. --- .../distributions/remote_hosted_distro/oci.md | 143 ++++++++++++++++++ docs/docs/providers/inference/remote_oci.mdx | 41 +++++ pyproject.toml | 1 + src/llama_stack/distributions/oci/__init__.py | 7 + src/llama_stack/distributions/oci/build.yaml | 35 +++++ .../distributions/oci/doc_template.md | 140 +++++++++++++++++ src/llama_stack/distributions/oci/oci.py | 108 +++++++++++++ src/llama_stack/distributions/oci/run.yaml | 136 +++++++++++++++++ .../providers/registry/inference.py | 14 ++ .../remote/inference/oci/__init__.py | 17 +++ .../providers/remote/inference/oci/auth.py | 79 ++++++++++ .../providers/remote/inference/oci/config.py | 75 +++++++++ .../providers/remote/inference/oci/oci.py | 140 +++++++++++++++++ .../inference/test_openai_completion.py | 1 + .../inference/test_openai_embeddings.py | 1 + 15 files changed, 938 insertions(+) create mode 100644 docs/docs/distributions/remote_hosted_distro/oci.md create mode 100644 docs/docs/providers/inference/remote_oci.mdx create mode 100644 src/llama_stack/distributions/oci/__init__.py create mode 100644 
src/llama_stack/distributions/oci/build.yaml create mode 100644 src/llama_stack/distributions/oci/doc_template.md create mode 100644 src/llama_stack/distributions/oci/oci.py create mode 100644 src/llama_stack/distributions/oci/run.yaml create mode 100644 src/llama_stack/providers/remote/inference/oci/__init__.py create mode 100644 src/llama_stack/providers/remote/inference/oci/auth.py create mode 100644 src/llama_stack/providers/remote/inference/oci/config.py create mode 100644 src/llama_stack/providers/remote/inference/oci/oci.py diff --git a/docs/docs/distributions/remote_hosted_distro/oci.md b/docs/docs/distributions/remote_hosted_distro/oci.md new file mode 100644 index 000000000..b13cf5f73 --- /dev/null +++ b/docs/docs/distributions/remote_hosted_distro/oci.md @@ -0,0 +1,143 @@ +--- +orphan: true +--- + +# OCI Distribution + +The `llamastack/distribution-oci` distribution consists of the following provider configurations. + +| API | Provider(s) | +|-----|-------------| +| agents | `inline::meta-reference` | +| datasetio | `remote::huggingface`, `inline::localfs` | +| eval | `inline::meta-reference` | +| files | `inline::localfs` | +| inference | `remote::oci` | +| safety | `inline::llama-guard` | +| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | +| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` | +| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | + + +### Environment Variables + +The following environment variables can be configured: + +- `OCI_AUTH_TYPE`: OCI authentication type (instance_principal or config_file) (default: `instance_principal`) +- `OCI_REGION`: OCI region (e.g., us-ashburn-1, us-chicago-1, us-phoenix-1, eu-frankfurt-1) (default: ``) +- `OCI_COMPARTMENT_OCID`: OCI compartment ID for the Generative AI service (default: ``) +- `OCI_CONFIG_FILE_PATH`: OCI config file path (required if OCI_AUTH_TYPE is config_file) (default: 
`~/.oci/config`) +- `OCI_CLI_PROFILE`: OCI CLI profile name to use from config file (default: `DEFAULT`) + + +## Prerequisites +### Oracle Cloud Infrastructure Setup + +Before using the OCI Generative AI distribution, ensure you have: + +1. **Oracle Cloud Infrastructure Account**: Sign up at [Oracle Cloud Infrastructure](https://cloud.oracle.com/) +2. **Generative AI Service Access**: Enable the Generative AI service in your OCI tenancy +3. **Compartment**: Create or identify a compartment where you'll deploy Generative AI models +4. **Authentication**: Configure authentication using either: + - **Instance Principal** (recommended for cloud-hosted deployments) + - **API Key** (for on-premises or development environments) + +### Authentication Methods + +#### Instance Principal Authentication (Recommended) +Instance Principal authentication allows OCI resources to authenticate using the identity of the compute instance they're running on. This is the most secure method for production deployments. + +Requirements: +- Instance must be running in an Oracle Cloud Infrastructure compartment +- Instance must have appropriate IAM policies to access Generative AI services + +#### API Key Authentication +For development or on-premises deployments, follow [this doc](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/apisigningkey.htm) to learn how to create your API signing key for your config file. + +### Required IAM Policies + +Ensure your OCI user or instance has the following policy statements: + +``` +Allow group to use generative-ai-inference-endpoints in compartment +Allow group to manage generative-ai-inference-endpoints in compartment +``` + +## Supported Services + +### Inference: OCI Generative AI +Oracle Cloud Infrastructure Generative AI provides access to high-performance AI models through OCI's Platform-as-a-Service offering. 
The service supports: + +- **Chat Completions**: Conversational AI with context awareness +- **Text Generation**: Complete prompts and generate text content + +#### Available Models +Common OCI Generative AI models include access to Meta, Cohere, OpenAI, Grok, and more models. + +### Safety: Llama Guard +For content safety and moderation, this distribution uses Meta's LlamaGuard model through the OCI Generative AI service to provide: +- Content filtering and moderation +- Policy compliance checking +- Harmful content detection + +### Vector Storage: Multiple Options +The distribution supports several vector storage providers: +- **FAISS**: Local in-memory vector search +- **ChromaDB**: Distributed vector database +- **PGVector**: PostgreSQL with vector extensions + +### Additional Services +- **Dataset I/O**: Local filesystem and Hugging Face integration +- **Tool Runtime**: Web search (Brave, Tavily) and RAG capabilities +- **Evaluation**: Meta reference evaluation framework + +## Running Llama Stack with OCI + +You can run the OCI distribution via Docker or local virtual environment. + +### Via venv + +If you've set up your local development environment, you can also build the image using your local virtual environment. + +```bash +OCI_AUTH=$OCI_AUTH_TYPE OCI_REGION=$OCI_REGION OCI_COMPARTMENT_OCID=$OCI_COMPARTMENT_OCID llama stack run --port 8321 oci +``` + +### Configuration Examples + +#### Using Instance Principal (Recommended for Production) +```bash +export OCI_AUTH_TYPE=instance_principal +export OCI_REGION=us-chicago-1 +export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1.. +``` + +#### Using API Key Authentication (Development) +```bash +export OCI_AUTH_TYPE=config_file +export OCI_CONFIG_FILE_PATH=~/.oci/config +export OCI_CLI_PROFILE=DEFAULT +export OCI_REGION=us-chicago-1 +export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..your-compartment-id +``` + +## Regional Endpoints + +OCI Generative AI is available in multiple regions. 
The service automatically routes to the appropriate regional endpoint based on your configuration. For a full list of regional model availability, visit: + +https://docs.oracle.com/en-us/iaas/Content/generative-ai/overview.htm#regions + +## Troubleshooting + +### Common Issues + +1. **Authentication Errors**: Verify your OCI credentials and IAM policies +2. **Model Not Found**: Ensure the model OCID is correct and the model is available in your region +3. **Permission Denied**: Check compartment permissions and Generative AI service access +4. **Region Unavailable**: Verify the specified region supports Generative AI services + +### Getting Help + +For additional support: +- [OCI Generative AI Documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm) +- [Llama Stack Issues](https://github.com/meta-llama/llama-stack/issues) diff --git a/docs/docs/providers/inference/remote_oci.mdx b/docs/docs/providers/inference/remote_oci.mdx new file mode 100644 index 000000000..33a201a55 --- /dev/null +++ b/docs/docs/providers/inference/remote_oci.mdx @@ -0,0 +1,41 @@ +--- +description: | + Oracle Cloud Infrastructure (OCI) Generative AI inference provider for accessing OCI's Generative AI Platform-as-a-Service models. + Provider documentation + https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm +sidebar_label: Remote - Oci +title: remote::oci +--- + +# remote::oci + +## Description + + +Oracle Cloud Infrastructure (OCI) Generative AI inference provider for accessing OCI's Generative AI Platform-as-a-Service models. +Provider documentation +https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm + + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. 
| +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | +| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider | +| `oci_auth_type` | `` | No | instance_principal | OCI authentication type (must be one of: instance_principal, config_file) | +| `oci_region` | `` | No | us-ashburn-1 | OCI region (e.g., us-ashburn-1) | +| `oci_compartment_id` | `` | No | | OCI compartment ID for the Generative AI service | +| `oci_config_file_path` | `` | No | ~/.oci/config | OCI config file path (required if oci_auth_type is config_file) | +| `oci_config_profile` | `` | No | DEFAULT | OCI config profile (required if oci_auth_type is config_file) | + +## Sample Configuration + +```yaml +oci_auth_type: ${env.OCI_AUTH_TYPE:=instance_principal} +oci_config_file_path: ${env.OCI_CONFIG_FILE_PATH:=~/.oci/config} +oci_config_profile: ${env.OCI_CLI_PROFILE:=DEFAULT} +oci_region: ${env.OCI_REGION:=us-ashburn-1} +oci_compartment_id: ${env.OCI_COMPARTMENT_OCID:=} +``` diff --git a/pyproject.toml b/pyproject.toml index 4ec83249c..653c6d613 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -298,6 +298,7 @@ exclude = [ "^src/llama_stack/providers/remote/agents/sample/", "^src/llama_stack/providers/remote/datasetio/huggingface/", "^src/llama_stack/providers/remote/datasetio/nvidia/", + "^src/llama_stack/providers/remote/inference/oci/", "^src/llama_stack/providers/remote/inference/bedrock/", "^src/llama_stack/providers/remote/inference/nvidia/", "^src/llama_stack/providers/remote/inference/passthrough/", diff --git a/src/llama_stack/distributions/oci/__init__.py b/src/llama_stack/distributions/oci/__init__.py new file mode 100644 index 000000000..68c0efe44 --- /dev/null +++ b/src/llama_stack/distributions/oci/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .oci import get_distribution_template # noqa: F401 diff --git a/src/llama_stack/distributions/oci/build.yaml b/src/llama_stack/distributions/oci/build.yaml new file mode 100644 index 000000000..7e082e1f6 --- /dev/null +++ b/src/llama_stack/distributions/oci/build.yaml @@ -0,0 +1,35 @@ +version: 2 +distribution_spec: + description: Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM + inference with scalable cloud services + providers: + inference: + - provider_type: remote::oci + vector_io: + - provider_type: inline::faiss + - provider_type: remote::chromadb + - provider_type: remote::pgvector + safety: + - provider_type: inline::llama-guard + agents: + - provider_type: inline::meta-reference + eval: + - provider_type: inline::meta-reference + datasetio: + - provider_type: remote::huggingface + - provider_type: inline::localfs + scoring: + - provider_type: inline::basic + - provider_type: inline::llm-as-judge + - provider_type: inline::braintrust + tool_runtime: + - provider_type: remote::brave-search + - provider_type: remote::tavily-search + - provider_type: inline::rag-runtime + - provider_type: remote::model-context-protocol + files: + - provider_type: inline::localfs +image_type: venv +additional_pip_packages: +- aiosqlite +- sqlalchemy[asyncio] diff --git a/src/llama_stack/distributions/oci/doc_template.md b/src/llama_stack/distributions/oci/doc_template.md new file mode 100644 index 000000000..320530ccd --- /dev/null +++ b/src/llama_stack/distributions/oci/doc_template.md @@ -0,0 +1,140 @@ +--- +orphan: true +--- +# OCI Distribution + +The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations. 
+ +{{ providers_table }} + +{% if run_config_env_vars %} +### Environment Variables + +The following environment variables can be configured: + +{% for var, (default_value, description) in run_config_env_vars.items() %} +- `{{ var }}`: {{ description }} (default: `{{ default_value }}`) +{% endfor %} +{% endif %} + +{% if default_models %} +### Models + +The following models are available by default: + +{% for model in default_models %} +- `{{ model.model_id }} {{ model.doc_string }}` +{% endfor %} +{% endif %} + +## Prerequisites +### Oracle Cloud Infrastructure Setup + +Before using the OCI Generative AI distribution, ensure you have: + +1. **Oracle Cloud Infrastructure Account**: Sign up at [Oracle Cloud Infrastructure](https://cloud.oracle.com/) +2. **Generative AI Service Access**: Enable the Generative AI service in your OCI tenancy +3. **Compartment**: Create or identify a compartment where you'll deploy Generative AI models +4. **Authentication**: Configure authentication using either: + - **Instance Principal** (recommended for cloud-hosted deployments) + - **API Key** (for on-premises or development environments) + +### Authentication Methods + +#### Instance Principal Authentication (Recommended) +Instance Principal authentication allows OCI resources to authenticate using the identity of the compute instance they're running on. This is the most secure method for production deployments. + +Requirements: +- Instance must be running in an Oracle Cloud Infrastructure compartment +- Instance must have appropriate IAM policies to access Generative AI services + +#### API Key Authentication +For development or on-premises deployments, follow [this doc](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/apisigningkey.htm) to learn how to create your API signing key for your config file. 
+ +### Required IAM Policies + +Ensure your OCI user or instance has the following policy statements: + +``` +Allow group to use generative-ai-inference-endpoints in compartment +Allow group to manage generative-ai-inference-endpoints in compartment +``` + +## Supported Services + +### Inference: OCI Generative AI +Oracle Cloud Infrastructure Generative AI provides access to high-performance AI models through OCI's Platform-as-a-Service offering. The service supports: + +- **Chat Completions**: Conversational AI with context awareness +- **Text Generation**: Complete prompts and generate text content + +#### Available Models +Common OCI Generative AI models include access to Meta, Cohere, OpenAI, Grok, and more models. + +### Safety: Llama Guard +For content safety and moderation, this distribution uses Meta's LlamaGuard model through the OCI Generative AI service to provide: +- Content filtering and moderation +- Policy compliance checking +- Harmful content detection + +### Vector Storage: Multiple Options +The distribution supports several vector storage providers: +- **FAISS**: Local in-memory vector search +- **ChromaDB**: Distributed vector database +- **PGVector**: PostgreSQL with vector extensions + +### Additional Services +- **Dataset I/O**: Local filesystem and Hugging Face integration +- **Tool Runtime**: Web search (Brave, Tavily) and RAG capabilities +- **Evaluation**: Meta reference evaluation framework + +## Running Llama Stack with OCI + +You can run the OCI distribution via Docker or local virtual environment. + +### Via venv + +If you've set up your local development environment, you can also build the image using your local virtual environment. 
+ +```bash +OCI_AUTH=$OCI_AUTH_TYPE OCI_REGION=$OCI_REGION OCI_COMPARTMENT_OCID=$OCI_COMPARTMENT_OCID llama stack run --port 8321 oci +``` + +### Configuration Examples + +#### Using Instance Principal (Recommended for Production) +```bash +export OCI_AUTH_TYPE=instance_principal +export OCI_REGION=us-chicago-1 +export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1.. +``` + +#### Using API Key Authentication (Development) +```bash +export OCI_AUTH_TYPE=config_file +export OCI_CONFIG_FILE_PATH=~/.oci/config +export OCI_CLI_PROFILE=DEFAULT +export OCI_REGION=us-chicago-1 +export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..your-compartment-id +``` + +## Regional Endpoints + +OCI Generative AI is available in multiple regions. The service automatically routes to the appropriate regional endpoint based on your configuration. For a full list of regional model availability, visit: + +https://docs.oracle.com/en-us/iaas/Content/generative-ai/overview.htm#regions + +## Troubleshooting + +### Common Issues + +1. **Authentication Errors**: Verify your OCI credentials and IAM policies +2. **Model Not Found**: Ensure the model OCID is correct and the model is available in your region +3. **Permission Denied**: Check compartment permissions and Generative AI service access +4. **Region Unavailable**: Verify the specified region supports Generative AI services + +### Getting Help + +For additional support: +- [OCI Generative AI Documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm) +- [Llama Stack Issues](https://github.com/meta-llama/llama-stack/issues) \ No newline at end of file diff --git a/src/llama_stack/distributions/oci/oci.py b/src/llama_stack/distributions/oci/oci.py new file mode 100644 index 000000000..1f21840f1 --- /dev/null +++ b/src/llama_stack/distributions/oci/oci.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pathlib import Path + +from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput +from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings +from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig +from llama_stack.providers.remote.inference.oci.config import OCIConfig + + +def get_distribution_template(name: str = "oci") -> DistributionTemplate: + providers = { + "inference": [BuildProvider(provider_type="remote::oci")], + "vector_io": [ + BuildProvider(provider_type="inline::faiss"), + BuildProvider(provider_type="remote::chromadb"), + BuildProvider(provider_type="remote::pgvector"), + ], + "safety": [BuildProvider(provider_type="inline::llama-guard")], + "agents": [BuildProvider(provider_type="inline::meta-reference")], + "eval": [BuildProvider(provider_type="inline::meta-reference")], + "datasetio": [ + BuildProvider(provider_type="remote::huggingface"), + BuildProvider(provider_type="inline::localfs"), + ], + "scoring": [ + BuildProvider(provider_type="inline::basic"), + BuildProvider(provider_type="inline::llm-as-judge"), + BuildProvider(provider_type="inline::braintrust"), + ], + "tool_runtime": [ + BuildProvider(provider_type="remote::brave-search"), + BuildProvider(provider_type="remote::tavily-search"), + BuildProvider(provider_type="inline::rag-runtime"), + BuildProvider(provider_type="remote::model-context-protocol"), + ], + "files": [BuildProvider(provider_type="inline::localfs")], + } + + inference_provider = Provider( + provider_id="oci", + provider_type="remote::oci", + config=OCIConfig.sample_run_config(), + ) + + vector_io_provider = Provider( + provider_id="faiss", + provider_type="inline::faiss", + 
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"), + ) + + files_provider = Provider( + provider_id="meta-reference-files", + provider_type="inline::localfs", + config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"), + ) + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ] + + return DistributionTemplate( + name=name, + distro_type="remote_hosted", + description="Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM inference with scalable cloud services", + container_image=None, + template_path=Path(__file__).parent / "doc_template.md", + providers=providers, + run_configs={ + "run.yaml": RunConfigSettings( + provider_overrides={ + "inference": [inference_provider], + "vector_io": [vector_io_provider], + "files": [files_provider], + }, + default_tool_groups=default_tool_groups, + ), + }, + run_config_env_vars={ + "OCI_AUTH_TYPE": ( + "instance_principal", + "OCI authentication type (instance_principal or config_file)", + ), + "OCI_REGION": ( + "", + "OCI region (e.g., us-ashburn-1, us-chicago-1, us-phoenix-1, eu-frankfurt-1)", + ), + "OCI_COMPARTMENT_OCID": ( + "", + "OCI compartment ID for the Generative AI service", + ), + "OCI_CONFIG_FILE_PATH": ( + "~/.oci/config", + "OCI config file path (required if OCI_AUTH_TYPE is config_file)", + ), + "OCI_CLI_PROFILE": ( + "DEFAULT", + "OCI CLI profile name to use from config file", + ), + }, + ) diff --git a/src/llama_stack/distributions/oci/run.yaml b/src/llama_stack/distributions/oci/run.yaml new file mode 100644 index 000000000..e385ec606 --- /dev/null +++ b/src/llama_stack/distributions/oci/run.yaml @@ -0,0 +1,136 @@ +version: 2 +image_name: oci +apis: +- agents +- datasetio +- eval +- files +- inference +- safety +- scoring +- tool_runtime +- vector_io +providers: + inference: + - provider_id: oci + provider_type: remote::oci + config: + oci_auth_type: 
${env.OCI_AUTH_TYPE:=instance_principal} + oci_config_file_path: ${env.OCI_CONFIG_FILE_PATH:=~/.oci/config} + oci_config_profile: ${env.OCI_CLI_PROFILE:=DEFAULT} + oci_region: ${env.OCI_REGION:=us-ashburn-1} + oci_compartment_id: ${env.OCI_COMPARTMENT_OCID:=} + vector_io: + - provider_id: faiss + provider_type: inline::faiss + config: + persistence: + namespace: vector_io::faiss + backend: kv_default + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence: + agent_state: + namespace: agents + backend: kv_default + responses: + table_name: responses + backend: sql_default + max_write_queue_size: 10000 + num_writers: 4 + eval: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + kvstore: + namespace: eval + backend: kv_default + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: + kvstore: + namespace: datasetio::huggingface + backend: kv_default + - provider_id: localfs + provider_type: inline::localfs + config: + kvstore: + namespace: datasetio::localfs + backend: kv_default + scoring: + - provider_id: basic + provider_type: inline::basic + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + - provider_id: braintrust + provider_type: inline::braintrust + config: + openai_api_key: ${env.OPENAI_API_KEY:=} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:=} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:=} + max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + files: + - provider_id: meta-reference-files + provider_type: inline::localfs + 
config: + storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/oci/files} + metadata_store: + table_name: files_metadata + backend: sql_default +storage: + backends: + kv_default: + type: kv_sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/kvstore.db + sql_default: + type: sql_sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/sql_store.db + stores: + metadata: + namespace: registry + backend: kv_default + inference: + table_name: inference_store + backend: sql_default + max_write_queue_size: 10000 + num_writers: 4 + conversations: + table_name: openai_conversations + backend: sql_default + prompts: + namespace: prompts + backend: kv_default +registered_resources: + models: [] + shields: [] + vector_dbs: [] + datasets: [] + scoring_fns: [] + benchmarks: [] + tool_groups: + - toolgroup_id: builtin::websearch + provider_id: tavily-search +server: + port: 8321 +telemetry: + enabled: true diff --git a/src/llama_stack/providers/registry/inference.py b/src/llama_stack/providers/registry/inference.py index 1b70182fc..3cbfd408b 100644 --- a/src/llama_stack/providers/registry/inference.py +++ b/src/llama_stack/providers/registry/inference.py @@ -297,6 +297,20 @@ Available Models: Azure OpenAI inference provider for accessing GPT models and other Azure services. Provider documentation https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview +""", + ), + RemoteProviderSpec( + api=Api.inference, + provider_type="remote::oci", + adapter_type="oci", + pip_packages=["oci"], + module="llama_stack.providers.remote.inference.oci", + config_class="llama_stack.providers.remote.inference.oci.config.OCIConfig", + provider_data_validator="llama_stack.providers.remote.inference.oci.config.OCIProviderDataValidator", + description=""" +Oracle Cloud Infrastructure (OCI) Generative AI inference provider for accessing OCI's Generative AI Platform-as-a-Service models. 
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import Generator, Mapping
from typing import Any, override

import httpx
import oci
import requests
from oci.config import DEFAULT_LOCATION, DEFAULT_PROFILE

# An OCI signer *instance*: the object whose do_request_sign() injects the OCI
# auth headers. The previous alias, `type[AbstractBaseSigner]`, described the
# class object, but instances are what is actually stored and passed around.
OciAuthSigner = oci.signer.AbstractBaseSigner


class HttpxOciAuth(httpx.Auth):
    """
    Custom HTTPX authentication class that implements OCI request signing.

    This class handles the authentication flow for HTTPX requests by signing them
    using the OCI Signer, which adds the necessary authentication headers for
    OCI API calls.

    Attributes:
        signer (oci.signer.AbstractBaseSigner): The OCI signer instance used for request signing
    """

    def __init__(self, signer: OciAuthSigner):
        self.signer = signer

    @override
    def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]:
        """Sign *request* in place with the OCI signer, then yield it."""
        # Read the request content to handle streaming requests properly
        try:
            content = request.content
        except httpx.RequestNotRead:
            # For streaming requests, we need to read the content first
            content = request.read()

        # The OCI SDK only knows how to sign `requests`-style prepared requests,
        # so mirror the HTTPX request into one, sign it, and copy the resulting
        # auth headers back onto the original request.
        req = requests.Request(
            method=request.method,
            url=str(request.url),
            headers=dict(request.headers),
            data=content,
        )
        prepared_request = req.prepare()

        # Sign the request using the OCI Signer
        self.signer.do_request_sign(prepared_request)  # type: ignore

        # Update the original HTTPX request with the signed headers
        request.headers.update(prepared_request.headers)

        yield request


class OciInstancePrincipalAuth(HttpxOciAuth):
    """Auth for workloads running on OCI compute, using instance-principal signing."""

    def __init__(self, **kwargs: Mapping[str, Any]):
        self.signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner(**kwargs)


class OciUserPrincipalAuth(HttpxOciAuth):
    """Auth based on a local OCI config profile (user API-key signing)."""

    def __init__(self, config_file: str = DEFAULT_LOCATION, profile_name: str = DEFAULT_PROFILE):
        config = oci.config.from_file(config_file, profile_name)
        oci.config.validate_config(config)  # type: ignore

        # Let the signer load the key from the configured file itself (no need
        # to also read it into memory), and forward the profile's pass phrase so
        # passphrase-protected keys work; the previous hard-coded
        # pass_phrase="none" broke encrypted private keys.
        self.signer = oci.signer.Signer(
            tenancy=config["tenancy"],
            user=config["user"],
            fingerprint=config["fingerprint"],
            private_key_file_location=config.get("key_file"),
            pass_phrase=config.get("pass_phrase"),
        )
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from typing import Any

from pydantic import BaseModel, Field

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type


def _env_default(var: str, fallback: str):
    """Return a zero-argument factory that reads *var* from the environment.

    Used as a pydantic ``default_factory`` so the environment is consulted at
    model-construction time, not at import time.
    """

    def _factory() -> str:
        return os.getenv(var, fallback)

    return _factory


class OCIProviderDataValidator(BaseModel):
    """Schema for per-request OCI provider data supplied by clients."""

    oci_auth_type: str = Field(
        description="OCI authentication type (must be one of: instance_principal, config_file)",
    )
    oci_region: str = Field(
        description="OCI region (e.g., us-ashburn-1)",
    )
    oci_compartment_id: str = Field(
        description="OCI compartment ID for the Generative AI service",
    )
    oci_config_file_path: str | None = Field(
        default="~/.oci/config",
        description="OCI config file path (required if oci_auth_type is config_file)",
    )
    oci_config_profile: str | None = Field(
        default="DEFAULT",
        description="OCI config profile (required if oci_auth_type is config_file)",
    )


@json_schema_type
class OCIConfig(RemoteInferenceProviderConfig):
    """Static provider configuration for the OCI Generative AI adapter.

    Every field falls back to an environment variable when not given
    explicitly, via :func:`_env_default`.
    """

    oci_auth_type: str = Field(
        description="OCI authentication type (must be one of: instance_principal, config_file)",
        default_factory=_env_default("OCI_AUTH_TYPE", "instance_principal"),
    )
    oci_region: str = Field(
        default_factory=_env_default("OCI_REGION", "us-ashburn-1"),
        description="OCI region (e.g., us-ashburn-1)",
    )
    oci_compartment_id: str = Field(
        default_factory=_env_default("OCI_COMPARTMENT_OCID", ""),
        description="OCI compartment ID for the Generative AI service",
    )
    oci_config_file_path: str = Field(
        default_factory=_env_default("OCI_CONFIG_FILE_PATH", "~/.oci/config"),
        description="OCI config file path (required if oci_auth_type is config_file)",
    )
    oci_config_profile: str = Field(
        default_factory=_env_default("OCI_CLI_PROFILE", "DEFAULT"),
        description="OCI config profile (required if oci_auth_type is config_file)",
    )

    @classmethod
    def sample_run_config(
        cls,
        oci_auth_type: str = "${env.OCI_AUTH_TYPE:=instance_principal}",
        oci_config_file_path: str = "${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}",
        oci_config_profile: str = "${env.OCI_CLI_PROFILE:=DEFAULT}",
        oci_region: str = "${env.OCI_REGION:=us-ashburn-1}",
        oci_compartment_id: str = "${env.OCI_COMPARTMENT_OCID:=}",
        **kwargs,
    ) -> dict[str, Any]:
        """Build the sample run-config mapping used in distribution templates."""
        return dict(
            oci_auth_type=oci_auth_type,
            oci_config_file_path=oci_config_file_path,
            oci_config_profile=oci_config_profile,
            oci_region=oci_region,
            oci_compartment_id=oci_compartment_id,
        )
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from collections.abc import Iterable
from typing import Any

import httpx
import oci
from oci.generative_ai.generative_ai_client import GenerativeAiClient
from oci.generative_ai.models import ModelCollection
from openai._base_client import DefaultAsyncHttpxClient

from llama_stack.apis.inference.inference import (
    OpenAIEmbeddingsRequestWithExtraBody,
    OpenAIEmbeddingsResponse,
)
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.oci.auth import OciInstancePrincipalAuth, OciUserPrincipalAuth
from llama_stack.providers.remote.inference.oci.config import OCIConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

logger = get_logger(name=__name__, category="inference::oci")

OCI_AUTH_TYPE_INSTANCE_PRINCIPAL = "instance_principal"
OCI_AUTH_TYPE_CONFIG_FILE = "config_file"
VALID_OCI_AUTH_TYPES = [OCI_AUTH_TYPE_INSTANCE_PRINCIPAL, OCI_AUTH_TYPE_CONFIG_FILE]
DEFAULT_OCI_REGION = "us-ashburn-1"

# Capability filter passed to list_models; only CHAT-capable, non-fine-tune
# models are ultimately registered (see list_provider_model_ids).
MODEL_CAPABILITIES = ["TEXT_GENERATION", "TEXT_SUMMARIZATION", "TEXT_EMBEDDINGS", "CHAT"]


class OCIInferenceAdapter(OpenAIMixin):
    """Inference adapter for OCI Generative AI, speaking the OpenAI-compatible API."""

    config: OCIConfig

    async def initialize(self) -> None:
        """Initialize and validate OCI configuration.

        Raises:
            ValueError: if the auth type is unknown or no compartment ID is set.
        """
        if self.config.oci_auth_type not in VALID_OCI_AUTH_TYPES:
            # Note: trailing space inside the first f-string; the two halves are
            # concatenated and previously ran together ("...config_file.Valid types...").
            raise ValueError(
                f"Invalid OCI authentication type: {self.config.oci_auth_type}. "
                f"Valid types are one of: {VALID_OCI_AUTH_TYPES}"
            )

        if not self.config.oci_compartment_id:
            raise ValueError("OCI_COMPARTMENT_OCID is a required parameter. Either set in env variable or config.")

    def get_base_url(self) -> str:
        """Return the OpenAI-compatible endpoint for the configured region."""
        region = self.config.oci_region or DEFAULT_OCI_REGION
        return f"https://inference.generativeai.{region}.oci.oraclecloud.com/20231130/actions/v1"

    def get_api_key(self) -> str | None:
        # OCI doesn't use API keys, it uses request signing
        return ""

    def get_extra_client_params(self) -> dict[str, Any]:
        """
        Get extra parameters for the AsyncOpenAI client, including OCI-specific auth and headers.
        """
        auth = self._get_auth()
        compartment_id = self.config.oci_compartment_id or ""

        return {
            "http_client": DefaultAsyncHttpxClient(
                auth=auth,
                headers={
                    "CompartmentId": compartment_id,
                },
            ),
        }

    def _get_oci_signer(self) -> oci.signer.AbstractBaseSigner | None:
        """Return an explicit signer for instance-principal auth; None otherwise
        (config-file auth lets the SDK derive the signer from the config dict)."""
        if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
            return oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
        return None

    def _get_oci_config(self) -> dict:
        """Build the OCI SDK config dict for the configured auth type.

        Raises:
            ValueError: if the auth type is unknown, or config-file auth lacks a region.
        """
        if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
            config = {"region": self.config.oci_region}
        elif self.config.oci_auth_type == OCI_AUTH_TYPE_CONFIG_FILE:
            config = oci.config.from_file(self.config.oci_config_file_path, self.config.oci_config_profile)
            if not config.get("region"):
                raise ValueError(
                    "Region not specified in config. Please specify in config or with OCI_REGION env variable."
                )
        else:
            # Previously `config` was left unbound here, surfacing as an
            # UnboundLocalError instead of a clear configuration error.
            raise ValueError(f"Invalid OCI authentication type: {self.config.oci_auth_type}")

        return config

    def _get_auth(self) -> httpx.Auth:
        """Return the httpx auth object matching the configured auth type."""
        if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
            return OciInstancePrincipalAuth()
        elif self.config.oci_auth_type == OCI_AUTH_TYPE_CONFIG_FILE:
            return OciUserPrincipalAuth(
                config_file=self.config.oci_config_file_path, profile_name=self.config.oci_config_profile
            )
        else:
            raise ValueError(f"Invalid OCI authentication type: {self.config.oci_auth_type}")

    async def list_provider_model_ids(self) -> Iterable[str]:
        """
        List available models from OCI Generative AI service.

        Filters to active, non-deprecated, CHAT-capable base models and
        de-duplicates by display name.
        """
        oci_config = self._get_oci_config()
        oci_signer = self._get_oci_signer()
        compartment_id = self.config.oci_compartment_id or ""

        if oci_signer is None:
            client = GenerativeAiClient(config=oci_config)
        else:
            client = GenerativeAiClient(config=oci_config, signer=oci_signer)

        models: ModelCollection = client.list_models(
            compartment_id=compartment_id, capability=MODEL_CAPABILITIES, lifecycle_state="ACTIVE"
        ).data

        seen_names: set[str] = set()
        model_ids = []
        for model in models.items:
            if model.time_deprecated or model.time_on_demand_retired:
                continue

            if "CHAT" not in model.capabilities or "FINE_TUNE" in model.capabilities:
                continue

            # De-duplicate by display_name (the same model can appear per-region/version).
            if model.display_name in seen_names:
                continue

            seen_names.add(model.display_name)
            model_ids.append(model.display_name)

        return model_ids

    async def openai_embeddings(self, params: OpenAIEmbeddingsRequestWithExtraBody) -> OpenAIEmbeddingsResponse:
        # The constructed url is a mask that hits OCI's "chat" action, which is not supported for embeddings.
        raise NotImplementedError("OCI Provider does not (currently) support embeddings")
Please check the URL for typos, # or see the docs at https://console.groq.com/docs/","type":"invalid_request_error","code":"unknown_url"}} "remote::groq", + "remote::oci", "remote::gemini", # https://generativelanguage.googleapis.com/v1beta/openai/completions -> 404 "remote::anthropic", # at least claude-3-{5,7}-{haiku,sonnet}-* / claude-{sonnet,opus}-4-* are not supported "remote::azure", # {'error': {'code': 'OperationNotSupported', 'message': 'The completion operation diff --git a/tests/integration/inference/test_openai_embeddings.py b/tests/integration/inference/test_openai_embeddings.py index 704775716..fe8070162 100644 --- a/tests/integration/inference/test_openai_embeddings.py +++ b/tests/integration/inference/test_openai_embeddings.py @@ -138,6 +138,7 @@ def skip_if_model_doesnt_support_openai_embeddings(client, model_id): "remote::runpod", "remote::sambanova", "remote::tgi", + "remote::oci", ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.") From 433438cfc00f5c1aaa34f9545fc44c0c624247d0 Mon Sep 17 00:00:00 2001 From: Shabana Baig <43451943+s-akhtar-baig@users.noreply.github.com> Date: Mon, 10 Nov 2025 16:21:27 -0500 Subject: [PATCH 04/15] feat: Implement the 'max_tool_calls' parameter for the Responses API (#4062) # Problem Responses API uses max_tool_calls parameter to limit the number of tool calls that can be generated in a response. Currently, LLS implementation of the Responses API does not support this parameter. # What does this PR do? This pull request adds the max_tool_calls field to the response object definition and updates the inline provider. 
it also ensures that: - the total number of calls to built-in and mcp tools do not exceed max_tool_calls - an error is thrown if max_tool_calls < 1 (behavior seen with the OpenAI Responses API, but we can change this if needed) Closes #[3563](https://github.com/llamastack/llama-stack/issues/3563) ## Test Plan - Tested manually for change in model response w.r.t supplied max_tool_calls field. - Added integration tests to test invalid max_tool_calls parameter. - Added integration tests to check max_tool_calls parameter with built-in and function tools. - Added integration tests to check max_tool_calls parameter in the returned response object. - Recorded OpenAI Responses API behavior using a sample script: https://github.com/s-akhtar-baig/llama-stack-examples/blob/main/responses/src/max_tool_calls.py Co-authored-by: Ashwin Bharambe --- client-sdks/stainless/openapi.yml | 15 ++ docs/static/llama-stack-spec.yaml | 15 ++ docs/static/stainless-llama-stack-spec.yaml | 15 ++ src/llama_stack/apis/agents/agents.py | 2 + .../apis/agents/openai_responses.py | 2 + .../inline/agents/meta_reference/agents.py | 2 + .../responses/openai_responses.py | 7 + .../meta_reference/responses/streaming.py | 18 +- .../agents/test_openai_responses.py | 166 ++++++++++++++++++ 9 files changed, 240 insertions(+), 2 deletions(-) diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml index 2b9849535..58ebaa8ae 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -6626,6 +6626,11 @@ components: type: string description: >- (Optional) System message inserted into the model's context + max_tool_calls: + type: integer + description: >- + (Optional) Max number of total calls to built-in tools that can be processed + in a response input: type: array items: @@ -6984,6 +6989,11 @@ components: (Optional) Additional fields to include in the response. 
max_infer_iters: type: integer + max_tool_calls: + type: integer + description: >- + (Optional) Max number of total calls to built-in tools that can be processed + in a response. additionalProperties: false required: - input @@ -7065,6 +7075,11 @@ components: type: string description: >- (Optional) System message inserted into the model's context + max_tool_calls: + type: integer + description: >- + (Optional) Max number of total calls to built-in tools that can be processed + in a response additionalProperties: false required: - created_at diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index 4680afac9..135ae910f 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -5910,6 +5910,11 @@ components: type: string description: >- (Optional) System message inserted into the model's context + max_tool_calls: + type: integer + description: >- + (Optional) Max number of total calls to built-in tools that can be processed + in a response input: type: array items: @@ -6268,6 +6273,11 @@ components: (Optional) Additional fields to include in the response. max_infer_iters: type: integer + max_tool_calls: + type: integer + description: >- + (Optional) Max number of total calls to built-in tools that can be processed + in a response. 
additionalProperties: false required: - input @@ -6349,6 +6359,11 @@ components: type: string description: >- (Optional) System message inserted into the model's context + max_tool_calls: + type: integer + description: >- + (Optional) Max number of total calls to built-in tools that can be processed + in a response additionalProperties: false required: - created_at diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index 2b9849535..58ebaa8ae 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -6626,6 +6626,11 @@ components: type: string description: >- (Optional) System message inserted into the model's context + max_tool_calls: + type: integer + description: >- + (Optional) Max number of total calls to built-in tools that can be processed + in a response input: type: array items: @@ -6984,6 +6989,11 @@ components: (Optional) Additional fields to include in the response. max_infer_iters: type: integer + max_tool_calls: + type: integer + description: >- + (Optional) Max number of total calls to built-in tools that can be processed + in a response. additionalProperties: false required: - input @@ -7065,6 +7075,11 @@ components: type: string description: >- (Optional) System message inserted into the model's context + max_tool_calls: + type: integer + description: >- + (Optional) Max number of total calls to built-in tools that can be processed + in a response additionalProperties: false required: - created_at diff --git a/src/llama_stack/apis/agents/agents.py b/src/llama_stack/apis/agents/agents.py index cadef2edc..09687ef33 100644 --- a/src/llama_stack/apis/agents/agents.py +++ b/src/llama_stack/apis/agents/agents.py @@ -87,6 +87,7 @@ class Agents(Protocol): "List of guardrails to apply during response generation. Guardrails provide safety and content moderation." 
), ] = None, + max_tool_calls: int | None = None, ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]: """Create a model response. @@ -97,6 +98,7 @@ class Agents(Protocol): :param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation. :param include: (Optional) Additional fields to include in the response. :param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications. + :param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response. :returns: An OpenAIResponseObject. """ ... diff --git a/src/llama_stack/apis/agents/openai_responses.py b/src/llama_stack/apis/agents/openai_responses.py index a38d1cba6..16657ab32 100644 --- a/src/llama_stack/apis/agents/openai_responses.py +++ b/src/llama_stack/apis/agents/openai_responses.py @@ -594,6 +594,7 @@ class OpenAIResponseObject(BaseModel): :param truncation: (Optional) Truncation strategy applied to the response :param usage: (Optional) Token usage information for the response :param instructions: (Optional) System message inserted into the model's context + :param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response """ created_at: int @@ -615,6 +616,7 @@ class OpenAIResponseObject(BaseModel): truncation: str | None = None usage: OpenAIResponseUsage | None = None instructions: str | None = None + max_tool_calls: int | None = None @json_schema_type diff --git a/src/llama_stack/providers/inline/agents/meta_reference/agents.py b/src/llama_stack/providers/inline/agents/meta_reference/agents.py index 7141d58bc..880e0b680 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -102,6 +102,7 @@ class 
MetaReferenceAgentsImpl(Agents): include: list[str] | None = None, max_infer_iters: int | None = 10, guardrails: list[ResponseGuardrail] | None = None, + max_tool_calls: int | None = None, ) -> OpenAIResponseObject: assert self.openai_responses_impl is not None, "OpenAI responses not initialized" result = await self.openai_responses_impl.create_openai_response( @@ -119,6 +120,7 @@ class MetaReferenceAgentsImpl(Agents): include, max_infer_iters, guardrails, + max_tool_calls, ) return result # type: ignore[no-any-return] diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py index 933cfe963..ed7f959c0 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -255,6 +255,7 @@ class OpenAIResponsesImpl: include: list[str] | None = None, max_infer_iters: int | None = 10, guardrails: list[str | ResponseGuardrailSpec] | None = None, + max_tool_calls: int | None = None, ): stream = bool(stream) text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text @@ -270,6 +271,9 @@ class OpenAIResponsesImpl: if not conversation.startswith("conv_"): raise InvalidConversationIdError(conversation) + if max_tool_calls is not None and max_tool_calls < 1: + raise ValueError(f"Invalid {max_tool_calls=}; should be >= 1") + stream_gen = self._create_streaming_response( input=input, conversation=conversation, @@ -282,6 +286,7 @@ class OpenAIResponsesImpl: tools=tools, max_infer_iters=max_infer_iters, guardrail_ids=guardrail_ids, + max_tool_calls=max_tool_calls, ) if stream: @@ -331,6 +336,7 @@ class OpenAIResponsesImpl: tools: list[OpenAIResponseInputTool] | None = None, max_infer_iters: int | None = 10, guardrail_ids: list[str] | None = None, + max_tool_calls: int | None = None, ) -> 
AsyncIterator[OpenAIResponseObjectStream]: # These should never be None when called from create_openai_response (which sets defaults) # but we assert here to help mypy understand the types @@ -373,6 +379,7 @@ class OpenAIResponsesImpl: safety_api=self.safety_api, guardrail_ids=guardrail_ids, instructions=instructions, + max_tool_calls=max_tool_calls, ) # Stream the response diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index ef5603420..c16bc8df3 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -115,6 +115,7 @@ class StreamingResponseOrchestrator: safety_api, guardrail_ids: list[str] | None = None, prompt: OpenAIResponsePrompt | None = None, + max_tool_calls: int | None = None, ): self.inference_api = inference_api self.ctx = ctx @@ -126,6 +127,10 @@ class StreamingResponseOrchestrator: self.safety_api = safety_api self.guardrail_ids = guardrail_ids or [] self.prompt = prompt + # System message that is inserted into the model's context + self.instructions = instructions + # Max number of total calls to built-in tools that can be processed in a response + self.max_tool_calls = max_tool_calls self.sequence_number = 0 # Store MCP tool mapping that gets built during tool processing self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ( @@ -139,8 +144,8 @@ class StreamingResponseOrchestrator: self.accumulated_usage: OpenAIResponseUsage | None = None # Track if we've sent a refusal response self.violation_detected = False - # system message that is inserted into the model's context - self.instructions = instructions + # Track total calls made to built-in tools + self.accumulated_builtin_tool_calls = 0 async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream: """Create a 
refusal response to replace streaming content.""" @@ -186,6 +191,7 @@ class StreamingResponseOrchestrator: usage=self.accumulated_usage, instructions=self.instructions, prompt=self.prompt, + max_tool_calls=self.max_tool_calls, ) async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]: @@ -894,6 +900,11 @@ class StreamingResponseOrchestrator: """Coordinate execution of both function and non-function tool calls.""" # Execute non-function tool calls for tool_call in non_function_tool_calls: + # Check if total calls made to built-in and mcp tools exceed max_tool_calls + if self.max_tool_calls is not None and self.accumulated_builtin_tool_calls >= self.max_tool_calls: + logger.info(f"Ignoring built-in and mcp tool call since reached the limit of {self.max_tool_calls=}.") + break + # Find the item_id for this tool call matching_item_id = None for index, item_id in completion_result_data.tool_call_item_ids.items(): @@ -974,6 +985,9 @@ class StreamingResponseOrchestrator: if tool_response_message: next_turn_messages.append(tool_response_message) + # Track number of calls made to built-in and mcp tools + self.accumulated_builtin_tool_calls += 1 + # Execute function tool calls (client-side) for tool_call in function_tool_calls: # Find the item_id for this tool call from our tracking dictionary diff --git a/tests/integration/agents/test_openai_responses.py b/tests/integration/agents/test_openai_responses.py index d413d5201..057cee774 100644 --- a/tests/integration/agents/test_openai_responses.py +++ b/tests/integration/agents/test_openai_responses.py @@ -516,3 +516,169 @@ def test_response_with_instructions(openai_client, client_with_models, text_mode # Verify instructions from previous response was not carried over to the next response assert response_with_instructions2.instructions == instructions2 + + +@pytest.mark.skip(reason="Tool calling is not reliable.") +def test_max_tool_calls_with_function_tools(openai_client, client_with_models, 
text_model_id): + """Test handling of max_tool_calls with function tools in responses.""" + if isinstance(client_with_models, LlamaStackAsLibraryClient): + pytest.skip("OpenAI responses are not supported when testing with library client yet.") + + client = openai_client + max_tool_calls = 1 + + tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get weather information for a specified location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name (e.g., 'New York', 'London')", + }, + }, + }, + }, + { + "type": "function", + "name": "get_time", + "description": "Get current time for a specified location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name (e.g., 'New York', 'London')", + }, + }, + }, + }, + ] + + # First create a response that triggers function tools + response = client.responses.create( + model=text_model_id, + input="Can you tell me the weather in Paris and the current time?", + tools=tools, + stream=False, + max_tool_calls=max_tool_calls, + ) + + # Verify we got two function calls and that the max_tool_calls do not affect function tools + assert len(response.output) == 2 + assert response.output[0].type == "function_call" + assert response.output[0].name == "get_weather" + assert response.output[0].status == "completed" + assert response.output[1].type == "function_call" + assert response.output[1].name == "get_time" + assert response.output[0].status == "completed" + + # Verify we have a valid max_tool_calls field + assert response.max_tool_calls == max_tool_calls + + +def test_max_tool_calls_invalid(openai_client, client_with_models, text_model_id): + """Test handling of invalid max_tool_calls in responses.""" + if isinstance(client_with_models, LlamaStackAsLibraryClient): + pytest.skip("OpenAI responses are not supported when testing with library client yet.") + + client = 
openai_client + + input = "Search for today's top technology news." + invalid_max_tool_calls = 0 + tools = [ + {"type": "web_search"}, + ] + + # Create a response with an invalid max_tool_calls value i.e. 0 + # Handle ValueError from LLS and BadRequestError from OpenAI client + with pytest.raises((ValueError, BadRequestError)) as excinfo: + client.responses.create( + model=text_model_id, + input=input, + tools=tools, + stream=False, + max_tool_calls=invalid_max_tool_calls, + ) + + error_message = str(excinfo.value) + assert f"Invalid max_tool_calls={invalid_max_tool_calls}; should be >= 1" in error_message, ( + f"Expected error message about invalid max_tool_calls, got: {error_message}" + ) + + +def test_max_tool_calls_with_builtin_tools(openai_client, client_with_models, text_model_id): + """Test handling of max_tool_calls with built-in tools in responses.""" + if isinstance(client_with_models, LlamaStackAsLibraryClient): + pytest.skip("OpenAI responses are not supported when testing with library client yet.") + + client = openai_client + + input = "Search for today's top technology and a positive news story. You MUST make exactly two separate web search calls." 
+ max_tool_calls = [1, 5] + tools = [ + {"type": "web_search"}, + ] + + # First create a response that triggers web_search tools without max_tool_calls + response = client.responses.create( + model=text_model_id, + input=input, + tools=tools, + stream=False, + ) + + # Verify we got two web search calls followed by a message + assert len(response.output) == 3 + assert response.output[0].type == "web_search_call" + assert response.output[0].status == "completed" + assert response.output[1].type == "web_search_call" + assert response.output[1].status == "completed" + assert response.output[2].type == "message" + assert response.output[2].status == "completed" + assert response.output[2].role == "assistant" + + # Next create a response that triggers web_search tools with max_tool_calls set to 1 + response_2 = client.responses.create( + model=text_model_id, + input=input, + tools=tools, + stream=False, + max_tool_calls=max_tool_calls[0], + ) + + # Verify we got one web search tool call followed by a message + assert len(response_2.output) == 2 + assert response_2.output[0].type == "web_search_call" + assert response_2.output[0].status == "completed" + assert response_2.output[1].type == "message" + assert response_2.output[1].status == "completed" + assert response_2.output[1].role == "assistant" + + # Verify we have a valid max_tool_calls field + assert response_2.max_tool_calls == max_tool_calls[0] + + # Finally create a response that triggers web_search tools with max_tool_calls set to 5 + response_3 = client.responses.create( + model=text_model_id, + input=input, + tools=tools, + stream=False, + max_tool_calls=max_tool_calls[1], + ) + + # Verify we got two web search calls followed by a message + assert len(response_3.output) == 3 + assert response_3.output[0].type == "web_search_call" + assert response_3.output[0].status == "completed" + assert response_3.output[1].type == "web_search_call" + assert response_3.output[1].status == "completed" + assert 
response_3.output[2].type == "message" + assert response_3.output[2].status == "completed" + assert response_3.output[2].role == "assistant" + + # Verify we have a valid max_tool_calls field + assert response_3.max_tool_calls == max_tool_calls[1] From 43adc23ef641d84d183a47afa8a8653fc092f9f7 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Mon, 10 Nov 2025 18:29:24 -0500 Subject: [PATCH 05/15] refactor: remove dead inference API code and clean up imports (#4093) # What does this PR do? Delete ~2,000 lines of dead code from the old bespoke inference API that was replaced by the OpenAI-only API. This includes removing unused type conversion functions, dead provider methods, and event_logger.py. Clean up imports across the codebase to remove references to deleted types. This eliminates unnecessary code and dependencies, helping isolate the API package as a self-contained module. This removes the last interdependency between the .api package and "exterior" packages, meaning that now every other package in llama stack imports the API, not the other way around. ## Test Plan This is a structural change; no tests are needed.
--------- Signed-off-by: Charlie Doern --- .../apis/inference/event_logger.py | 43 - src/llama_stack/apis/inference/inference.py | 182 +--- src/llama_stack/core/routers/safety.py | 4 +- .../models/llama/llama3/generation.py | 4 +- .../models/llama/llama3/interface.py | 7 +- .../llama3/prompt_templates/system_prompts.py | 2 +- .../models/llama/llama3/tool_utils.py | 3 +- .../llama4/prompt_templates/system_prompts.py | 2 +- .../inference/meta_reference/generators.py | 98 +- .../inference/meta_reference/inference.py | 395 ++++++- .../meta_reference/model_parallel.py | 33 +- .../meta_reference/parallel_utils.py | 14 +- .../sentence_transformers.py | 4 - .../utils/inference/litellm_openai_mixin.py | 51 - .../utils/inference/openai_compat.py | 971 +----------------- .../utils/inference/prompt_adapter.py | 287 +----- tests/unit/models/test_prompt_adapter.py | 303 ------ .../providers/inline/inference/__init__.py | 5 + .../inline/inference/test_meta_reference.py | 44 + tests/unit/providers/nvidia/test_safety.py | 27 +- .../utils/inference/test_openai_compat.py | 220 ---- .../utils/inference/test_prompt_adapter.py | 35 + 22 files changed, 593 insertions(+), 2141 deletions(-) delete mode 100644 src/llama_stack/apis/inference/event_logger.py delete mode 100644 tests/unit/models/test_prompt_adapter.py create mode 100644 tests/unit/providers/inline/inference/__init__.py create mode 100644 tests/unit/providers/inline/inference/test_meta_reference.py delete mode 100644 tests/unit/providers/utils/inference/test_openai_compat.py create mode 100644 tests/unit/providers/utils/inference/test_prompt_adapter.py diff --git a/src/llama_stack/apis/inference/event_logger.py b/src/llama_stack/apis/inference/event_logger.py deleted file mode 100644 index d97ece6d4..000000000 --- a/src/llama_stack/apis/inference/event_logger.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from termcolor import cprint - -from llama_stack.apis.inference import ( - ChatCompletionResponseEventType, - ChatCompletionResponseStreamChunk, -) - - -class LogEvent: - def __init__( - self, - content: str = "", - end: str = "\n", - color="white", - ): - self.content = content - self.color = color - self.end = "\n" if end is None else end - - def print(self, flush=True): - cprint(f"{self.content}", color=self.color, end=self.end, flush=flush) - - -class EventLogger: - async def log(self, event_generator): - async for chunk in event_generator: - if isinstance(chunk, ChatCompletionResponseStreamChunk): - event = chunk.event - if event.event_type == ChatCompletionResponseEventType.start: - yield LogEvent("Assistant> ", color="cyan", end="") - elif event.event_type == ChatCompletionResponseEventType.progress: - yield LogEvent(event.delta, color="yellow", end="") - elif event.event_type == ChatCompletionResponseEventType.complete: - yield LogEvent("") - else: - yield LogEvent("Assistant> ", color="cyan", end="") - yield LogEvent(chunk.completion_message.content, color="yellow") diff --git a/src/llama_stack/apis/inference/inference.py b/src/llama_stack/apis/inference/inference.py index 1a865ce5f..9f04917c9 100644 --- a/src/llama_stack/apis/inference/inference.py +++ b/src/llama_stack/apis/inference/inference.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
from collections.abc import AsyncIterator -from enum import Enum +from enum import Enum, StrEnum from typing import ( Annotated, Any, @@ -15,28 +15,18 @@ from typing import ( ) from fastapi import Body -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from typing_extensions import TypedDict -from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent -from llama_stack.apis.common.responses import MetricResponseMixin, Order +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.responses import ( + Order, +) from llama_stack.apis.common.tracing import telemetry_traceable from llama_stack.apis.models import Model from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA -from llama_stack.models.llama.datatypes import ( - BuiltinTool, - StopReason, - ToolCall, - ToolDefinition, - ToolPromptFormat, -) from llama_stack.schema_utils import json_schema_type, register_schema, webmethod -register_schema(ToolCall) -register_schema(ToolDefinition) - -from enum import StrEnum - @json_schema_type class GreedySamplingStrategy(BaseModel): @@ -201,58 +191,6 @@ class ToolResponseMessage(BaseModel): content: InterleavedContent -@json_schema_type -class CompletionMessage(BaseModel): - """A message containing the model's (assistant) response in a chat conversation. - - :param role: Must be "assistant" to identify this as the model's response - :param content: The content of the model's response - :param stop_reason: Reason why the model stopped generating. Options are: - - `StopReason.end_of_turn`: The model finished generating the entire response. - - `StopReason.end_of_message`: The model finished generating but generated a partial response -- usually, a tool call. The user may call the tool and continue the conversation with the tool's response. - - `StopReason.out_of_tokens`: The model ran out of token budget. 
- :param tool_calls: List of tool calls. Each tool call is a ToolCall object. - """ - - role: Literal["assistant"] = "assistant" - content: InterleavedContent - stop_reason: StopReason - tool_calls: list[ToolCall] | None = Field(default_factory=lambda: []) - - -Message = Annotated[ - UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage, - Field(discriminator="role"), -] -register_schema(Message, name="Message") - - -@json_schema_type -class ToolResponse(BaseModel): - """Response from a tool invocation. - - :param call_id: Unique identifier for the tool call this response is for - :param tool_name: Name of the tool that was invoked - :param content: The response content from the tool - :param metadata: (Optional) Additional metadata about the tool response - """ - - call_id: str - tool_name: BuiltinTool | str - content: InterleavedContent - metadata: dict[str, Any] | None = None - - @field_validator("tool_name", mode="before") - @classmethod - def validate_field(cls, v): - if isinstance(v, str): - try: - return BuiltinTool(v) - except ValueError: - return v - return v - - class ToolChoice(Enum): """Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model. @@ -289,22 +227,6 @@ class ChatCompletionResponseEventType(Enum): progress = "progress" -@json_schema_type -class ChatCompletionResponseEvent(BaseModel): - """An event during chat completion generation. - - :param event_type: Type of the event - :param delta: Content generated since last event. This can be one or more tokens, or a tool call. 
- :param logprobs: Optional log probabilities for generated tokens - :param stop_reason: Optional reason why generation stopped, if complete - """ - - event_type: ChatCompletionResponseEventType - delta: ContentDelta - logprobs: list[TokenLogProbs] | None = None - stop_reason: StopReason | None = None - - class ResponseFormatType(StrEnum): """Types of formats for structured (guided) decoding. @@ -357,34 +279,6 @@ class CompletionRequest(BaseModel): logprobs: LogProbConfig | None = None -@json_schema_type -class CompletionResponse(MetricResponseMixin): - """Response from a completion request. - - :param content: The generated completion text - :param stop_reason: Reason why generation stopped - :param logprobs: Optional log probabilities for generated tokens - """ - - content: str - stop_reason: StopReason - logprobs: list[TokenLogProbs] | None = None - - -@json_schema_type -class CompletionResponseStreamChunk(MetricResponseMixin): - """A chunk of a streamed completion response. - - :param delta: New content generated since last chunk. This can be one or more tokens. - :param stop_reason: Optional reason why generation stopped, if complete - :param logprobs: Optional log probabilities for generated tokens - """ - - delta: str - stop_reason: StopReason | None = None - logprobs: list[TokenLogProbs] | None = None - - class SystemMessageBehavior(Enum): """Config for how to override the default system prompt. @@ -398,70 +292,6 @@ class SystemMessageBehavior(Enum): replace = "replace" -@json_schema_type -class ToolConfig(BaseModel): - """Configuration for tool use. - - :param tool_choice: (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto. - :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. 
- - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls. - :param system_message_behavior: (Optional) Config for how to override the default system prompt. - - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string - '{{function_definitions}}' to indicate where the function definitions should be inserted. - """ - - tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto) - tool_prompt_format: ToolPromptFormat | None = Field(default=None) - system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append) - - def model_post_init(self, __context: Any) -> None: - if isinstance(self.tool_choice, str): - try: - self.tool_choice = ToolChoice[self.tool_choice] - except KeyError: - pass - - -# This is an internally used class -@json_schema_type -class ChatCompletionRequest(BaseModel): - model: str - messages: list[Message] - sampling_params: SamplingParams | None = Field(default_factory=SamplingParams) - - tools: list[ToolDefinition] | None = Field(default_factory=lambda: []) - tool_config: ToolConfig | None = Field(default_factory=ToolConfig) - - response_format: ResponseFormat | None = None - stream: bool | None = False - logprobs: LogProbConfig | None = None - - -@json_schema_type -class ChatCompletionResponseStreamChunk(MetricResponseMixin): - """A chunk of a streamed chat completion response. - - :param event: The event containing the new content - """ - - event: ChatCompletionResponseEvent - - -@json_schema_type -class ChatCompletionResponse(MetricResponseMixin): - """Response from a chat completion request. 
- - :param completion_message: The complete response message - :param logprobs: Optional log probabilities for generated tokens - """ - - completion_message: CompletionMessage - logprobs: list[TokenLogProbs] | None = None - - @json_schema_type class EmbeddingsResponse(BaseModel): """Response containing generated embeddings. diff --git a/src/llama_stack/core/routers/safety.py b/src/llama_stack/core/routers/safety.py index 79eac8b46..e5ff2ada9 100644 --- a/src/llama_stack/core/routers/safety.py +++ b/src/llama_stack/core/routers/safety.py @@ -6,7 +6,7 @@ from typing import Any -from llama_stack.apis.inference import Message +from llama_stack.apis.inference import OpenAIMessageParam from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.shields import Shield @@ -52,7 +52,7 @@ class SafetyRouter(Safety): async def run_shield( self, shield_id: str, - messages: list[Message], + messages: list[OpenAIMessageParam], params: dict[str, Any] = None, ) -> RunShieldResponse: logger.debug(f"SafetyRouter.run_shield: {shield_id}") diff --git a/src/llama_stack/models/llama/llama3/generation.py b/src/llama_stack/models/llama/llama3/generation.py index fe7be5ea9..9ac215c3b 100644 --- a/src/llama_stack/models/llama/llama3/generation.py +++ b/src/llama_stack/models/llama/llama3/generation.py @@ -26,8 +26,10 @@ from fairscale.nn.model_parallel.initialize import ( ) from termcolor import cprint +from llama_stack.models.llama.datatypes import ToolPromptFormat + from ..checkpoint import maybe_reshard_state_dict -from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat +from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage from .args import ModelArgs from .chat_format import ChatFormat, LLMInput from .model import Transformer diff --git a/src/llama_stack/models/llama/llama3/interface.py 
b/src/llama_stack/models/llama/llama3/interface.py index b63ba4847..89be31a55 100644 --- a/src/llama_stack/models/llama/llama3/interface.py +++ b/src/llama_stack/models/llama/llama3/interface.py @@ -15,13 +15,10 @@ from pathlib import Path from termcolor import colored +from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall, ToolDefinition, ToolPromptFormat + from ..datatypes import ( - BuiltinTool, RawMessage, - StopReason, - ToolCall, - ToolDefinition, - ToolPromptFormat, ) from . import template_data from .chat_format import ChatFormat diff --git a/src/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py b/src/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py index 11a5993e9..3fbaa103e 100644 --- a/src/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py +++ b/src/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py @@ -15,7 +15,7 @@ import textwrap from datetime import datetime from typing import Any -from llama_stack.apis.inference import ( +from llama_stack.models.llama.datatypes import ( BuiltinTool, ToolDefinition, ) diff --git a/src/llama_stack/models/llama/llama3/tool_utils.py b/src/llama_stack/models/llama/llama3/tool_utils.py index 8c12fe680..6f919e1fa 100644 --- a/src/llama_stack/models/llama/llama3/tool_utils.py +++ b/src/llama_stack/models/llama/llama3/tool_utils.py @@ -8,8 +8,9 @@ import json import re from llama_stack.log import get_logger +from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolPromptFormat -from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat +from ..datatypes import RecursiveType logger = get_logger(name=__name__, category="models::llama") diff --git a/src/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py b/src/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py index 1ee570933..feded9f8c 100644 --- a/src/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py 
+++ b/src/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py @@ -13,7 +13,7 @@ import textwrap -from llama_stack.apis.inference import ToolDefinition +from llama_stack.models.llama.datatypes import ToolDefinition from llama_stack.models.llama.llama3.prompt_templates.base import ( PromptTemplate, PromptTemplateGeneratorBase, diff --git a/src/llama_stack/providers/inline/inference/meta_reference/generators.py b/src/llama_stack/providers/inline/inference/meta_reference/generators.py index cb926f529..51a2ddfad 100644 --- a/src/llama_stack/providers/inline/inference/meta_reference/generators.py +++ b/src/llama_stack/providers/inline/inference/meta_reference/generators.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import math -from collections.abc import Generator from typing import Optional import torch @@ -14,21 +13,19 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken from llama_stack.apis.inference import ( GreedySamplingStrategy, JsonSchemaResponseFormat, + OpenAIChatCompletionRequestWithExtraBody, + OpenAIResponseFormatJSONSchema, ResponseFormat, + ResponseFormatType, SamplingParams, TopPSamplingStrategy, ) -from llama_stack.models.llama.datatypes import QuantizationMode +from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat from llama_stack.models.llama.llama3.generation import Llama3 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama4.generation import Llama4 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer from llama_stack.models.llama.sku_types import Model, ModelFamily -from llama_stack.providers.utils.inference.prompt_adapter import ( - ChatCompletionRequestWithRawContent, - CompletionRequestWithRawContent, - get_default_tool_prompt_format, -) from .common import model_checkpoint_dir from .config import MetaReferenceInferenceConfig @@ -106,14 +103,6 @@ def 
_infer_sampling_params(sampling_params: SamplingParams): return temperature, top_p -def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent): - tool_config = request.tool_config - if tool_config is not None and tool_config.tool_prompt_format is not None: - return tool_config.tool_prompt_format - else: - return get_default_tool_prompt_format(request.model) - - class LlamaGenerator: def __init__( self, @@ -157,55 +146,56 @@ class LlamaGenerator: self.args = self.inner_generator.args self.formatter = self.inner_generator.formatter - def completion( - self, - request_batch: list[CompletionRequestWithRawContent], - ) -> Generator: - first_request = request_batch[0] - sampling_params = first_request.sampling_params or SamplingParams() - max_gen_len = sampling_params.max_tokens - if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: - max_gen_len = self.args.max_seq_len - 1 - - temperature, top_p = _infer_sampling_params(sampling_params) - yield from self.inner_generator.generate( - llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch], - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - logprobs=bool(first_request.logprobs), - echo=False, - logits_processor=get_logits_processor( - self.tokenizer, - self.args.vocab_size, - first_request.response_format, - ), - ) - def chat_completion( self, - request_batch: list[ChatCompletionRequestWithRawContent], - ) -> Generator: - first_request = request_batch[0] - sampling_params = first_request.sampling_params or SamplingParams() + request: OpenAIChatCompletionRequestWithExtraBody, + raw_messages: list, + ): + """Generate chat completion using OpenAI request format. 
+ + Args: + request: OpenAI chat completion request + raw_messages: Pre-converted list of RawMessage objects + """ + + # Determine tool prompt format + tool_prompt_format = ToolPromptFormat.json if request.tools else ToolPromptFormat.json + + # Prepare sampling params + sampling_params = SamplingParams() + if request.temperature is not None or request.top_p is not None: + sampling_params.strategy = TopPSamplingStrategy( + temperature=request.temperature if request.temperature is not None else 1.0, + top_p=request.top_p if request.top_p is not None else 1.0, + ) + if request.max_tokens: + sampling_params.max_tokens = request.max_tokens + max_gen_len = sampling_params.max_tokens if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: max_gen_len = self.args.max_seq_len - 1 temperature, top_p = _infer_sampling_params(sampling_params) + + # Get logits processor for response format + logits_processor = None + if request.response_format: + if isinstance(request.response_format, OpenAIResponseFormatJSONSchema): + # Extract the actual schema from OpenAIJSONSchema TypedDict + schema_dict = request.response_format.json_schema.get("schema") or {} + json_schema_format = JsonSchemaResponseFormat( + type=ResponseFormatType.json_schema, + json_schema=schema_dict, + ) + logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format) + + # Generate yield from self.inner_generator.generate( - llm_inputs=[ - self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)) - for request in request_batch - ], + llm_inputs=[self.formatter.encode_dialog_prompt(raw_messages, tool_prompt_format)], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, - logprobs=bool(first_request.logprobs), + logprobs=False, echo=False, - logits_processor=get_logits_processor( - self.tokenizer, - self.args.vocab_size, - first_request.response_format, - ), + logits_processor=logits_processor, ) diff --git 
a/src/llama_stack/providers/inline/inference/meta_reference/inference.py b/src/llama_stack/providers/inline/inference/meta_reference/inference.py index 76d3fdd50..ef21132a0 100644 --- a/src/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/src/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -5,12 +5,19 @@ # the root directory of this source tree. import asyncio +import time +import uuid from collections.abc import AsyncIterator from llama_stack.apis.inference import ( InferenceProvider, + OpenAIAssistantMessageParam, OpenAIChatCompletionRequestWithExtraBody, + OpenAIChatCompletionUsage, + OpenAIChoice, OpenAICompletionRequestWithExtraBody, + OpenAIUserMessageParam, + ToolChoice, ) from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, @@ -19,12 +26,20 @@ from llama_stack.apis.inference.inference import ( ) from llama_stack.apis.models import Model, ModelType from llama_stack.log import get_logger +from llama_stack.models.llama.datatypes import RawMessage, RawTextItem, ToolDefinition from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat +from llama_stack.models.llama.llama3.prompt_templates import ( + JsonCustomToolGenerator, + SystemDefaultGenerator, +) from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat +from llama_stack.models.llama.llama4.prompt_templates.system_prompts import ( + PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4, +) from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer from llama_stack.models.llama.sku_list import resolve_model -from llama_stack.models.llama.sku_types import ModelFamily +from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin 
import ( SentenceTransformerEmbeddingMixin, @@ -44,6 +59,170 @@ log = get_logger(__name__, category="inference") SEMAPHORE = asyncio.Semaphore(1) +def _convert_openai_tool_to_tool_definition(tool) -> ToolDefinition: + """Convert OpenAI tool format to ToolDefinition format.""" + # OpenAI tools have function.name and function.parameters + return ToolDefinition( + tool_name=tool.function.name, + description=tool.function.description or "", + parameters=tool.function.parameters or {}, + ) + + +def _get_tool_choice_prompt(tool_choice, tools) -> str: + """Generate prompt text for tool_choice behavior.""" + if not tool_choice or tool_choice == ToolChoice.auto or tool_choice == "auto": + return "" + elif tool_choice == ToolChoice.required or tool_choice == "required": + return "You MUST use one of the provided functions/tools to answer the user query." + elif tool_choice == ToolChoice.none or tool_choice == "none": + return "" + else: + # Specific tool specified + return f"You MUST use the tool `{tool_choice}` to answer the user query." 
+ + +def _raw_content_as_str(content) -> str: + """Convert RawContent to string for system messages.""" + if isinstance(content, str): + return content + elif isinstance(content, RawTextItem): + return content.text + elif isinstance(content, list): + return "\n".join(_raw_content_as_str(c) for c in content) + else: + return "" + + +def _augment_raw_messages_for_tools_llama_3_1( + raw_messages: list[RawMessage], + tools: list, + tool_choice, +) -> list[RawMessage]: + """Augment raw messages with tool definitions for Llama 3.1 style models.""" + messages = raw_messages.copy() + existing_system_message = None + if messages and messages[0].role == "system": + existing_system_message = messages.pop(0) + + sys_content = "" + + # Add tool definitions first (if present) + if tools: + # Convert OpenAI tools to ToolDefinitions + tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools] + + # For OpenAI format, all tools are custom (have string names) + tool_gen = JsonCustomToolGenerator() + tool_template = tool_gen.gen(tool_definitions) + sys_content += tool_template.render() + sys_content += "\n" + + # Add default system prompt + default_gen = SystemDefaultGenerator() + default_template = default_gen.gen() + sys_content += default_template.render() + + # Add existing system message if present + if existing_system_message: + sys_content += "\n" + _raw_content_as_str(existing_system_message.content) + + # Add tool choice prompt if needed + if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools): + sys_content += "\n" + tool_choice_prompt + + # Create new system message + new_system_message = RawMessage( + role="system", + content=[RawTextItem(text=sys_content.strip())], + ) + + return [new_system_message] + messages + + +def _augment_raw_messages_for_tools_llama_4( + raw_messages: list[RawMessage], + tools: list, + tool_choice, +) -> list[RawMessage]: + """Augment raw messages with tool definitions for Llama 4/3.2/3.3 style models.""" + 
messages = raw_messages.copy() + existing_system_message = None + if messages and messages[0].role == "system": + existing_system_message = messages.pop(0) + + sys_content = "" + + # Add tool definitions if present + if tools: + # Convert OpenAI tools to ToolDefinitions + tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools] + + # Use python_list format for Llama 4 + tool_gen = PythonListCustomToolGeneratorLlama4() + system_prompt = None + if existing_system_message: + system_prompt = _raw_content_as_str(existing_system_message.content) + + tool_template = tool_gen.gen(tool_definitions, system_prompt) + sys_content = tool_template.render() + elif existing_system_message: + # No tools, just use existing system message + sys_content = _raw_content_as_str(existing_system_message.content) + + # Add tool choice prompt if needed + if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools): + sys_content += "\n" + tool_choice_prompt + + if sys_content: + new_system_message = RawMessage( + role="system", + content=[RawTextItem(text=sys_content.strip())], + ) + return [new_system_message] + messages + + return messages + + +def augment_raw_messages_for_tools( + raw_messages: list[RawMessage], + params: OpenAIChatCompletionRequestWithExtraBody, + llama_model, +) -> list[RawMessage]: + """Augment raw messages with tool definitions based on model family.""" + if not params.tools: + return raw_messages + + # Determine augmentation strategy based on model family + if llama_model.model_family == ModelFamily.llama3_1 or ( + llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id) + ): + # Llama 3.1 and Llama 3.2 multimodal use JSON format + return _augment_raw_messages_for_tools_llama_3_1( + raw_messages, + params.tools, + params.tool_choice, + ) + elif llama_model.model_family in ( + ModelFamily.llama3_2, + ModelFamily.llama3_3, + ModelFamily.llama4, + ): + # Llama 3.2/3.3/4 use python_list format + return 
_augment_raw_messages_for_tools_llama_4( + raw_messages, + params.tools, + params.tool_choice, + ) + else: + # Default to Llama 3.1 style + return _augment_raw_messages_for_tools_llama_3_1( + raw_messages, + params.tools, + params.tool_choice, + ) + + def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator: return LlamaGenerator(config, model_id, llama_model) @@ -136,10 +315,13 @@ class MetaReferenceInferenceImpl( self.llama_model = llama_model log.info("Warming up...") + await self.openai_chat_completion( - model=model_id, - messages=[{"role": "user", "content": "Hi how are you?"}], - max_tokens=20, + params=OpenAIChatCompletionRequestWithExtraBody( + model=model_id, + messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")], + max_tokens=20, + ) ) log.info("Warmed up!") @@ -155,4 +337,207 @@ class MetaReferenceInferenceImpl( self, params: OpenAIChatCompletionRequestWithExtraBody, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider") + self.check_model(params) + + # Convert OpenAI messages to RawMessages + from llama_stack.models.llama.datatypes import StopReason + from llama_stack.providers.utils.inference.prompt_adapter import ( + convert_openai_message_to_raw_message, + decode_assistant_message, + ) + + raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages] + + # Augment messages with tool definitions if tools are present + raw_messages = augment_raw_messages_for_tools(raw_messages, params, self.llama_model) + + # Call generator's chat_completion method (works for both single-GPU and model-parallel) + if isinstance(self.generator, LlamaGenerator): + generator = self.generator.chat_completion(params, raw_messages) + else: + # Model parallel: submit task to process group + generator = 
self.generator.group.run_inference(("chat_completion", [params, raw_messages])) + + # Check if streaming is requested + if params.stream: + return self._stream_chat_completion(generator, params) + + # Non-streaming: collect all generated text + generated_text = "" + for result_batch in generator: + for result in result_batch: + if not result.ignore_token and result.source == "output": + generated_text += result.text + + # Decode assistant message to extract tool calls and determine stop_reason + # Default to end_of_turn if generation completed normally + decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn) + + # Convert tool calls to OpenAI format + openai_tool_calls = None + if decoded_message.tool_calls: + from llama_stack.apis.inference import ( + OpenAIChatCompletionToolCall, + OpenAIChatCompletionToolCallFunction, + ) + + openai_tool_calls = [ + OpenAIChatCompletionToolCall( + # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative. + id=f"call_{uuid.uuid4().hex[:24]}", + type="function", + function=OpenAIChatCompletionToolCallFunction( + name=str(tc.tool_name), + arguments=tc.arguments, + ), + ) + for tc in decoded_message.tool_calls + ] + + # Determine finish_reason based on whether tool calls are present + finish_reason = "tool_calls" if openai_tool_calls else "stop" + + # Extract content from decoded message + content = "" + if isinstance(decoded_message.content, str): + content = decoded_message.content + elif isinstance(decoded_message.content, list): + for item in decoded_message.content: + if isinstance(item, RawTextItem): + content += item.text + + # Create OpenAI response + # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative. 
+ response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" + created = int(time.time()) + + return OpenAIChatCompletion( + id=response_id, + object="chat.completion", + created=created, + model=params.model, + choices=[ + OpenAIChoice( + index=0, + message=OpenAIAssistantMessageParam( + role="assistant", + content=content, + tool_calls=openai_tool_calls, + ), + finish_reason=finish_reason, + logprobs=None, + ) + ], + usage=OpenAIChatCompletionUsage( + prompt_tokens=0, # TODO: calculate properly + completion_tokens=0, # TODO: calculate properly + total_tokens=0, # TODO: calculate properly + ), + ) + + async def _stream_chat_completion( + self, + generator, + params: OpenAIChatCompletionRequestWithExtraBody, + ) -> AsyncIterator[OpenAIChatCompletionChunk]: + """Stream chat completion chunks as they're generated.""" + from llama_stack.apis.inference import ( + OpenAIChatCompletionChunk, + OpenAIChatCompletionToolCall, + OpenAIChatCompletionToolCallFunction, + OpenAIChoiceDelta, + OpenAIChunkChoice, + ) + from llama_stack.models.llama.datatypes import StopReason + from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message + + response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" + created = int(time.time()) + generated_text = "" + + # Yield chunks as tokens are generated + for result_batch in generator: + for result in result_batch: + if result.ignore_token or result.source != "output": + continue + + generated_text += result.text + + # Yield delta chunk with the new text + chunk = OpenAIChatCompletionChunk( + id=response_id, + object="chat.completion.chunk", + created=created, + model=params.model, + choices=[ + OpenAIChunkChoice( + index=0, + delta=OpenAIChoiceDelta( + role="assistant", + content=result.text, + ), + finish_reason="", + logprobs=None, + ) + ], + ) + yield chunk + + # After generation completes, decode the full message to extract tool calls + decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn) + + # If 
tool calls are present, yield a final chunk with tool_calls + if decoded_message.tool_calls: + openai_tool_calls = [ + OpenAIChatCompletionToolCall( + # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative. + id=f"call_{uuid.uuid4().hex[:24]}", + type="function", + function=OpenAIChatCompletionToolCallFunction( + name=str(tc.tool_name), + arguments=tc.arguments, + ), + ) + for tc in decoded_message.tool_calls + ] + + # Yield chunk with tool_calls + chunk = OpenAIChatCompletionChunk( + id=response_id, + object="chat.completion.chunk", + created=created, + model=params.model, + choices=[ + OpenAIChunkChoice( + index=0, + delta=OpenAIChoiceDelta( + role="assistant", + tool_calls=openai_tool_calls, + ), + finish_reason="", + logprobs=None, + ) + ], + ) + yield chunk + + finish_reason = "tool_calls" + else: + finish_reason = "stop" + + # Yield final chunk with finish_reason + final_chunk = OpenAIChatCompletionChunk( + id=response_id, + object="chat.completion.chunk", + created=created, + model=params.model, + choices=[ + OpenAIChunkChoice( + index=0, + delta=OpenAIChoiceDelta(), + finish_reason=finish_reason, + logprobs=None, + ) + ], + ) + yield final_chunk diff --git a/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py index 9d0295d65..f50b41f34 100644 --- a/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py +++ b/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py @@ -4,17 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from collections.abc import Callable, Generator -from copy import deepcopy +from collections.abc import Callable from functools import partial from typing import Any from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat -from llama_stack.providers.utils.inference.prompt_adapter import ( - ChatCompletionRequestWithRawContent, - CompletionRequestWithRawContent, -) from .parallel_utils import ModelParallelProcessGroup @@ -23,12 +18,14 @@ class ModelRunner: def __init__(self, llama): self.llama = llama - # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()` def __call__(self, task: Any): - if task[0] == "chat_completion": - return self.llama.chat_completion(task[1]) + task_type = task[0] + if task_type == "chat_completion": + # task[1] is [params, raw_messages] + params, raw_messages = task[1] + return self.llama.chat_completion(params, raw_messages) else: - raise ValueError(f"Unexpected task type {task[0]}") + raise ValueError(f"Unexpected task type {task_type}") def init_model_cb( @@ -78,19 +75,3 @@ class LlamaModelParallelGenerator: def __exit__(self, exc_type, exc_value, exc_traceback): self.group.stop() - - def completion( - self, - request_batch: list[CompletionRequestWithRawContent], - ) -> Generator: - req_obj = deepcopy(request_batch) - gen = self.group.run_inference(("completion", req_obj)) - yield from gen - - def chat_completion( - self, - request_batch: list[ChatCompletionRequestWithRawContent], - ) -> Generator: - req_obj = deepcopy(request_batch) - gen = self.group.run_inference(("chat_completion", req_obj)) - yield from gen diff --git a/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index bb6a1bd03..663e4793b 100644 --- 
a/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -33,10 +33,6 @@ from torch.distributed.launcher.api import LaunchConfig, elastic_launch from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import GenerationResult -from llama_stack.providers.utils.inference.prompt_adapter import ( - ChatCompletionRequestWithRawContent, - CompletionRequestWithRawContent, -) log = get_logger(name=__name__, category="inference") @@ -69,10 +65,7 @@ class CancelSentinel(BaseModel): class TaskRequest(BaseModel): type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request - task: tuple[ - str, - list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent], - ] + task: tuple[str, list] class TaskResponse(BaseModel): @@ -328,10 +321,7 @@ class ModelParallelProcessGroup: def run_inference( self, - req: tuple[ - str, - list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent], - ], + req: tuple[str, list], ) -> Generator: assert not self.running, "inference already running" diff --git a/src/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/src/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index cb72aa13a..e6dcf3ae7 100644 --- a/src/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/src/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -22,9 +22,6 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, ) -from llama_stack.providers.utils.inference.openai_compat import ( - OpenAIChatCompletionToLlamaStackMixin, -) from .config import SentenceTransformersInferenceConfig @@ -32,7 +29,6 @@ log = 
get_logger(name=__name__, category="inference") class SentenceTransformersInferenceImpl( - OpenAIChatCompletionToLlamaStackMixin, SentenceTransformerEmbeddingMixin, InferenceProvider, ModelsProtocolPrivate, diff --git a/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 223497fb8..a793c499e 100644 --- a/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -11,9 +11,7 @@ from collections.abc import AsyncIterator import litellm from llama_stack.apis.inference import ( - ChatCompletionRequest, InferenceProvider, - JsonSchemaResponseFormat, OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAIChatCompletionRequestWithExtraBody, @@ -23,15 +21,11 @@ from llama_stack.apis.inference import ( OpenAIEmbeddingsRequestWithExtraBody, OpenAIEmbeddingsResponse, OpenAIEmbeddingUsage, - ToolChoice, ) from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry from llama_stack.providers.utils.inference.openai_compat import ( - convert_message_to_openai_dict_new, - convert_tooldef_to_openai_tool, - get_sampling_options, prepare_openai_completion_params, ) @@ -127,51 +121,6 @@ class LiteLLMOpenAIMixin( return schema - async def _get_params(self, request: ChatCompletionRequest) -> dict: - from typing import Any - - input_dict: dict[str, Any] = {} - - input_dict["messages"] = [ - await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages - ] - if fmt := request.response_format: - if not isinstance(fmt, JsonSchemaResponseFormat): - raise ValueError( - f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported." 
- ) - - # Convert to dict for manipulation - fmt_dict = dict(fmt.json_schema) - name = fmt_dict["title"] - del fmt_dict["title"] - fmt_dict["additionalProperties"] = False - - # Apply additionalProperties: False recursively to all objects - fmt_dict = self._add_additional_properties_recursive(fmt_dict) - - input_dict["response_format"] = { - "type": "json_schema", - "json_schema": { - "name": name, - "schema": fmt_dict, - "strict": self.json_schema_strict, - }, - } - if request.tools: - input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools] - if request.tool_config and (tool_choice := request.tool_config.tool_choice): - input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice - - return { - "model": request.model, - "api_key": self.get_api_key(), - "api_base": self.api_base, - **input_dict, - "stream": request.stream, - **get_sampling_options(request.sampling_params), - } - def get_api_key(self) -> str: provider_data = self.get_request_provider_data() key_field = self.provider_data_api_key_field diff --git a/src/llama_stack/providers/utils/inference/openai_compat.py b/src/llama_stack/providers/utils/inference/openai_compat.py index aabcb50f8..c2e6829e0 100644 --- a/src/llama_stack/providers/utils/inference/openai_compat.py +++ b/src/llama_stack/providers/utils/inference/openai_compat.py @@ -3,31 +3,14 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import json -import time -import uuid -import warnings -from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Iterable +from collections.abc import Iterable from typing import ( Any, ) -from openai import AsyncStream -from openai.types.chat import ( - ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, -) -from openai.types.chat import ( - ChatCompletionChunk as OpenAIChatCompletionChunk, -) -from openai.types.chat import ( - ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, -) from openai.types.chat import ( ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, ) -from openai.types.chat import ( - ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, -) try: from openai.types.chat import ( @@ -37,84 +20,24 @@ except ImportError: from openai.types.chat.chat_completion_message_tool_call import ( ChatCompletionMessageToolCall as OpenAIChatCompletionMessageFunctionToolCall, ) -from openai.types.chat import ( - ChatCompletionMessageParam as OpenAIChatCompletionMessage, -) from openai.types.chat import ( ChatCompletionMessageToolCall, ) -from openai.types.chat import ( - ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, -) -from openai.types.chat import ( - ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, -) -from openai.types.chat import ( - ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, -) -from openai.types.chat.chat_completion import ( - Choice as OpenAIChoice, -) -from openai.types.chat.chat_completion import ( - ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs -) -from openai.types.chat.chat_completion_chunk import ( - Choice as OpenAIChatCompletionChunkChoice, -) -from openai.types.chat.chat_completion_chunk import ( - ChoiceDelta as OpenAIChoiceDelta, -) -from openai.types.chat.chat_completion_chunk import ( - ChoiceDeltaToolCall as 
OpenAIChoiceDeltaToolCall, -) -from openai.types.chat.chat_completion_chunk import ( - ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction, -) -from openai.types.chat.chat_completion_content_part_image_param import ( - ImageURL as OpenAIImageURL, -) -from openai.types.chat.chat_completion_message_tool_call import ( - Function as OpenAIFunction, -) from pydantic import BaseModel from llama_stack.apis.common.content_types import ( URL, ImageContentItem, - InterleavedContent, TextContentItem, - TextDelta, - ToolCallDelta, - ToolCallParseStatus, _URLOrData, ) from llama_stack.apis.inference import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseEvent, - ChatCompletionResponseEventType, - ChatCompletionResponseStreamChunk, - CompletionMessage, - CompletionResponse, - CompletionResponseStreamChunk, GreedySamplingStrategy, JsonSchemaResponseFormat, - Message, - OpenAIChatCompletion, - OpenAIMessageParam, OpenAIResponseFormatParam, SamplingParams, - SystemMessage, - TokenLogProbs, - ToolChoice, - ToolConfig, - ToolResponseMessage, TopKSamplingStrategy, TopPSamplingStrategy, - UserMessage, -) -from llama_stack.apis.inference import ( - OpenAIChoice as OpenAIChatCompletionChoice, ) from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( @@ -123,10 +46,6 @@ from llama_stack.models.llama.datatypes import ( ToolCall, ToolDefinition, ) -from llama_stack.providers.utils.inference.prompt_adapter import ( - convert_image_content_to_url, - decode_assistant_message, -) logger = get_logger(name=__name__, category="providers::utils") @@ -213,345 +132,6 @@ def get_stop_reason(finish_reason: str) -> StopReason: return StopReason.out_of_tokens -def convert_openai_completion_logprobs( - logprobs: OpenAICompatLogprobs | None, -) -> list[TokenLogProbs] | None: - if not logprobs: - return None - if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs: - return [TokenLogProbs(logprobs_by_token=x) for x in 
logprobs.top_logprobs] - - # Together supports logprobs with top_k=1 only. This means for each token position, - # they return only the logprobs for the selected token (vs. the top n most likely tokens). - # Here we construct the response by matching the selected token with the logprobs. - if logprobs.tokens and logprobs.token_logprobs: - return [ - TokenLogProbs(logprobs_by_token={token: token_lp}) - for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False) - ] - return None - - -def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenAICompatLogprobs | None): - if logprobs is None: - return None - if isinstance(logprobs, float): - # Adapt response from Together CompletionChoicesChunk - return [TokenLogProbs(logprobs_by_token={text: logprobs})] - if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs: - return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs] - return None - - -def process_completion_response( - response: OpenAICompatCompletionResponse, -) -> CompletionResponse: - choice = response.choices[0] - text = choice.text or "" - # drop suffix if present and return stop reason as end of turn - if text.endswith("<|eot_id|>"): - return CompletionResponse( - stop_reason=StopReason.end_of_turn, - content=text[: -len("<|eot_id|>")], - logprobs=convert_openai_completion_logprobs(choice.logprobs), - ) - # drop suffix if present and return stop reason as end of message - if text.endswith("<|eom_id|>"): - return CompletionResponse( - stop_reason=StopReason.end_of_message, - content=text[: -len("<|eom_id|>")], - logprobs=convert_openai_completion_logprobs(choice.logprobs), - ) - return CompletionResponse( - stop_reason=get_stop_reason(choice.finish_reason or "stop"), - content=text, - logprobs=convert_openai_completion_logprobs(choice.logprobs), - ) - - -def process_chat_completion_response( - response: OpenAICompatCompletionResponse, - request: ChatCompletionRequest, -) -> 
ChatCompletionResponse: - choice = response.choices[0] - if choice.finish_reason == "tool_calls": - if not hasattr(choice, "message") or not choice.message or not choice.message.tool_calls: # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed - raise ValueError("Tool calls are not present in the response") - - tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed - if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls): - # If we couldn't parse a tool call, jsonify the tool calls and return them - return ChatCompletionResponse( - completion_message=CompletionMessage( - stop_reason=StopReason.end_of_turn, - content=json.dumps(tool_calls, default=lambda x: x.model_dump()), - ), - logprobs=None, - ) - else: - # Otherwise, return tool calls as normal - # Filter to only valid ToolCall objects - valid_tool_calls = [tc for tc in tool_calls if isinstance(tc, ToolCall)] - return ChatCompletionResponse( - completion_message=CompletionMessage( - tool_calls=valid_tool_calls, - stop_reason=StopReason.end_of_turn, - # Content is not optional - content="", - ), - logprobs=None, - ) - - # TODO: This does not work well with tool calls for vLLM remote provider - # Ref: https://github.com/meta-llama/llama-stack/issues/1058 - raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason or "stop")) - - # NOTE: If we do not set tools in chat-completion request, we should not - # expect the ToolCall in the response. Instead, we should return the raw - # response from the model. 
- if raw_message.tool_calls: - if not request.tools: - raw_message.tool_calls = [] - raw_message.content = text_from_choice(choice) - else: - # only return tool_calls if provided in the request - new_tool_calls = [] - request_tools = {t.tool_name: t for t in request.tools} - for t in raw_message.tool_calls: - if t.tool_name in request_tools: - new_tool_calls.append(t) - else: - logger.warning(f"Tool {t.tool_name} not found in request tools") - - if len(new_tool_calls) < len(raw_message.tool_calls): - raw_message.tool_calls = new_tool_calls - raw_message.content = text_from_choice(choice) - - return ChatCompletionResponse( - completion_message=CompletionMessage( - content=raw_message.content, # type: ignore[arg-type] # decode_assistant_message returns Union[str, InterleavedContent] - stop_reason=raw_message.stop_reason or StopReason.end_of_turn, - tool_calls=raw_message.tool_calls, - ), - logprobs=None, - ) - - -async def process_completion_stream_response( - stream: AsyncGenerator[OpenAICompatCompletionResponse, None], -) -> AsyncGenerator[CompletionResponseStreamChunk, None]: - stop_reason = None - - async for chunk in stream: - choice = chunk.choices[0] - finish_reason = choice.finish_reason - - text = text_from_choice(choice) - if text == "<|eot_id|>": - stop_reason = StopReason.end_of_turn - text = "" - continue - elif text == "<|eom_id|>": - stop_reason = StopReason.end_of_message - text = "" - continue - yield CompletionResponseStreamChunk( - delta=text, - stop_reason=stop_reason, - logprobs=convert_openai_completion_logprobs_stream(text, choice.logprobs), - ) - if finish_reason: - if finish_reason in ["stop", "eos", "eos_token"]: - stop_reason = StopReason.end_of_turn - elif finish_reason == "length": - stop_reason = StopReason.out_of_tokens - break - - yield CompletionResponseStreamChunk( - delta="", - stop_reason=stop_reason, - ) - - -async def process_chat_completion_stream_response( - stream: AsyncGenerator[OpenAICompatCompletionResponse, None], - 
request: ChatCompletionRequest, -) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta=TextDelta(text=""), - ) - ) - - buffer = "" - ipython = False - stop_reason = None - - async for chunk in stream: - choice = chunk.choices[0] - finish_reason = choice.finish_reason - - if finish_reason: - if stop_reason is None and finish_reason in ["stop", "eos", "eos_token"]: - stop_reason = StopReason.end_of_turn - elif stop_reason is None and finish_reason == "length": - stop_reason = StopReason.out_of_tokens - break - - text = text_from_choice(choice) - if not text: - # Sometimes you get empty chunks from providers - continue - - # check if its a tool call ( aka starts with <|python_tag|> ) - if not ipython and text.startswith("<|python_tag|>"): - ipython = True - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call="", - parse_status=ToolCallParseStatus.started, - ), - ) - ) - buffer += text - continue - - if text == "<|eot_id|>": - stop_reason = StopReason.end_of_turn - text = "" - continue - elif text == "<|eom_id|>": - stop_reason = StopReason.end_of_message - text = "" - continue - - if ipython: - buffer += text - delta = ToolCallDelta( - tool_call=text, - parse_status=ToolCallParseStatus.in_progress, - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=delta, - stop_reason=stop_reason, - ) - ) - else: - buffer += text - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=TextDelta(text=text), - stop_reason=stop_reason, - ) - ) - - # parse tool calls and report errors - message = decode_assistant_message(buffer, stop_reason or 
StopReason.end_of_turn) - - parsed_tool_calls = len(message.tool_calls) > 0 - if ipython and not parsed_tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call="", - parse_status=ToolCallParseStatus.failed, - ), - stop_reason=stop_reason, - ) - ) - - request_tools = {t.tool_name: t for t in (request.tools or [])} - for tool_call in message.tool_calls: - if tool_call.tool_name in request_tools: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call=tool_call, - parse_status=ToolCallParseStatus.succeeded, - ), - stop_reason=stop_reason, - ) - ) - else: - logger.warning(f"Tool {tool_call.tool_name} not found in request tools") - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - # Parsing tool call failed due to tool call not being found in request tools, - # We still add the raw message text inside tool_call for responding back to the user - tool_call=buffer, - parse_status=ToolCallParseStatus.failed, - ), - stop_reason=stop_reason, - ) - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta=TextDelta(text=""), - stop_reason=stop_reason, - ) - ) - - -async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict: - async def _convert_content(content) -> dict: - if isinstance(content, ImageContentItem): - return { - "type": "image_url", - "image_url": { - "url": await convert_image_content_to_url(content, download=download), - }, - } - else: - text = content.text if isinstance(content, TextContentItem) else content - assert isinstance(text, str) - return {"type": "text", "text": text} - - if 
isinstance(message.content, list): - content = [await _convert_content(c) for c in message.content] - else: - content = [await _convert_content(message.content)] - - result = { - "role": message.role, - "content": content, - } - - if hasattr(message, "tool_calls") and message.tool_calls: - tool_calls_list = [] - for tc in message.tool_calls: - # The tool.tool_name can be a str or a BuiltinTool enum. If - # it's the latter, convert to a string. - tool_name = tc.tool_name - if isinstance(tool_name, BuiltinTool): - tool_name = tool_name.value - - tool_calls_list.append( - { - "id": tc.call_id, - "type": "function", - "function": { - "name": tool_name, - "arguments": tc.arguments, - }, - } - ) - result["tool_calls"] = tool_calls_list # type: ignore[assignment] # dict allows Any value, stricter type expected - return result - - class UnparseableToolCall(BaseModel): """ A ToolCall with arguments that are not valid JSON. @@ -563,112 +143,6 @@ class UnparseableToolCall(BaseModel): arguments: str = "" -async def convert_message_to_openai_dict_new( - message: Message | dict, - download_images: bool = False, -) -> OpenAIChatCompletionMessage: - """ - Convert a Message to an OpenAI API-compatible dictionary. - """ - # users can supply a dict instead of a Message object, we'll - # convert it to a Message object and proceed with some type safety. 
- if isinstance(message, dict): - if "role" not in message: - raise ValueError("role is required in message") - if message["role"] == "user": - message = UserMessage(**message) - elif message["role"] == "assistant": - message = CompletionMessage(**message) - elif message["role"] == "tool": - message = ToolResponseMessage(**message) - elif message["role"] == "system": - message = SystemMessage(**message) - else: - raise ValueError(f"Unsupported message role: {message['role']}") - - # Map Llama Stack spec to OpenAI spec - - # str -> str - # {"type": "text", "text": ...} -> {"type": "text", "text": ...} - # {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}} - # {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}} - # List[...] -> List[...] - async def _convert_message_content( - content: InterleavedContent, - ) -> str | Iterable[OpenAIChatCompletionContentPartParam]: - async def impl( - content_: InterleavedContent, - ) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]: - # Llama Stack and OpenAI spec match for str and text input - if isinstance(content_, str): - return content_ - elif isinstance(content_, TextContentItem): - return OpenAIChatCompletionContentPartTextParam( - type="text", - text=content_.text, - ) - elif isinstance(content_, ImageContentItem): - return OpenAIChatCompletionContentPartImageParam( - type="image_url", - image_url=OpenAIImageURL( - url=await convert_image_content_to_url(content_, download=download_images) - ), - ) - elif isinstance(content_, list): - return [await impl(item) for item in content_] # type: ignore[misc] # recursive list comprehension confuses mypy's type narrowing - else: - raise ValueError(f"Unsupported content type: {type(content_)}") - - ret = await impl(content) - - # OpenAI*Message expects a str or list - if isinstance(ret, str) or isinstance(ret, list): - return ret - 
else: - return [ret] - - out: OpenAIChatCompletionMessage - if isinstance(message, UserMessage): - out = OpenAIChatCompletionUserMessage( - role="user", - content=await _convert_message_content(message.content), - ) - elif isinstance(message, CompletionMessage): - tool_calls = [ - OpenAIChatCompletionMessageFunctionToolCall( - id=tool.call_id, - function=OpenAIFunction( - name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), - arguments=tool.arguments, # Already a JSON string, don't double-encode - ), - type="function", - ) - for tool in (message.tool_calls or []) - ] - params = {} - if tool_calls: - params["tool_calls"] = tool_calls - out = OpenAIChatCompletionAssistantMessage( - role="assistant", - content=await _convert_message_content(message.content), - **params, # type: ignore[typeddict-item] # tool_calls dict expansion conflicts with TypedDict optional field - ) - elif isinstance(message, ToolResponseMessage): - out = OpenAIChatCompletionToolMessage( - role="tool", - tool_call_id=message.call_id, - content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement - ) - elif isinstance(message, SystemMessage): - out = OpenAIChatCompletionSystemMessage( - role="system", - content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement - ) - else: - raise ValueError(f"Unsupported message type: {type(message)}") - - return out - - def convert_tool_call( tool_call: ChatCompletionMessageToolCall, ) -> ToolCall | UnparseableToolCall: @@ -817,17 +291,6 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason: }.get(finish_reason, StopReason.end_of_turn) -def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None = None) -> ToolConfig: - tool_config = ToolConfig() - if tool_choice: - try: - tool_choice = 
ToolChoice(tool_choice) # type: ignore[assignment] # reassigning to enum narrows union but mypy can't track after exception - except ValueError: - pass - tool_config.tool_choice = tool_choice # type: ignore[assignment] # ToolConfig.tool_choice accepts Union[ToolChoice, dict] but mypy tracks narrower type - return tool_config - - def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]: lls_tools: list[ToolDefinition] = [] if not tools: @@ -898,40 +361,6 @@ def _convert_openai_tool_calls( ] -def _convert_openai_logprobs( - logprobs: OpenAIChoiceLogprobs, -) -> list[TokenLogProbs] | None: - """ - Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs. - - OpenAI ChoiceLogprobs: - content: Optional[List[ChatCompletionTokenLogprob]] - - OpenAI ChatCompletionTokenLogprob: - token: str - logprob: float - top_logprobs: List[TopLogprob] - - OpenAI TopLogprob: - token: str - logprob: float - - -> - - TokenLogProbs: - logprobs_by_token: Dict[str, float] - - token, logprob - - """ - if not logprobs or not logprobs.content: - return None - - return [ - TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs}) - for content in logprobs.content - ] - - def _convert_openai_sampling_params( max_tokens: int | None = None, temperature: float | None = None, @@ -956,37 +385,6 @@ def _convert_openai_sampling_params( return sampling_params -def openai_messages_to_messages( - messages: list[OpenAIMessageParam], -) -> list[Message]: - """ - Convert a list of OpenAIChatCompletionMessage into a list of Message. 
- """ - converted_messages: list[Message] = [] - for message in messages: - converted_message: Message - if message.role == "system": - converted_message = SystemMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types - elif message.role == "user": - converted_message = UserMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types - elif message.role == "assistant": - converted_message = CompletionMessage( - content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types - tool_calls=_convert_openai_tool_calls(message.tool_calls) if message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls type incompatible with conversion function - stop_reason=StopReason.end_of_turn, - ) - elif message.role == "tool": - converted_message = ToolResponseMessage( - role="tool", - call_id=message.tool_call_id, - content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types - ) - else: - raise ValueError(f"Unknown role {message.role}") - converted_messages.append(converted_message) - return converted_messages - - def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionContentPartParam] | None): if content is None: return "" @@ -1005,216 +403,6 @@ def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionConten raise ValueError(f"Unknown content type: {content}") -def convert_openai_chat_completion_choice( - choice: OpenAIChoice, -) -> ChatCompletionResponse: - """ - Convert an OpenAI Choice into a ChatCompletionResponse. 
- - OpenAI Choice: - message: ChatCompletionMessage - finish_reason: str - logprobs: Optional[ChoiceLogprobs] - - OpenAI ChatCompletionMessage: - role: Literal["assistant"] - content: Optional[str] - tool_calls: Optional[List[ChatCompletionMessageToolCall]] - - -> - - ChatCompletionResponse: - completion_message: CompletionMessage - logprobs: Optional[List[TokenLogProbs]] - - CompletionMessage: - role: Literal["assistant"] - content: str | ImageMedia | List[str | ImageMedia] - stop_reason: StopReason - tool_calls: List[ToolCall] - - class StopReason(Enum): - end_of_turn = "end_of_turn" - end_of_message = "end_of_message" - out_of_tokens = "out_of_tokens" - """ - assert hasattr(choice, "message") and choice.message, "error in server response: message not found" - assert hasattr(choice, "finish_reason") and choice.finish_reason, ( - "error in server response: finish_reason not found" - ) - - return ChatCompletionResponse( - completion_message=CompletionMessage( - content=choice.message.content or "", # CompletionMessage content is not optional - stop_reason=_convert_openai_finish_reason(choice.finish_reason), - tool_calls=_convert_openai_tool_calls(choice.message.tool_calls) if choice.message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls Optional type broadens union - ), - logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)), # type: ignore[arg-type] # getattr returns Any, can't narrow without inspection - ) - - -async def convert_openai_chat_completion_stream( - stream: AsyncStream[OpenAIChatCompletionChunk], - enable_incremental_tool_calls: bool, -) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: - """ - Convert a stream of OpenAI chat completion chunks into a stream - of ChatCompletionResponseStreamChunk. 
- """ - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta=TextDelta(text=""), - ) - ) - event_type = ChatCompletionResponseEventType.progress - - stop_reason = None - tool_call_idx_to_buffer = {} - - async for chunk in stream: - choice = chunk.choices[0] # assuming only one choice per chunk - - # we assume there's only one finish_reason in the stream - stop_reason = _convert_openai_finish_reason(choice.finish_reason) if choice.finish_reason else stop_reason - logprobs = getattr(choice, "logprobs", None) - - # if there's a tool call, emit an event for each tool in the list - # if tool call and content, emit both separately - if choice.delta.tool_calls: - # the call may have content and a tool call. ChatCompletionResponseEvent - # does not support both, so we emit the content first - if choice.delta.content: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=TextDelta(text=choice.delta.content), - logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result - ) - ) - - # it is possible to have parallel tool calls in stream, but - # ChatCompletionResponseEvent only supports one per stream - if len(choice.delta.tool_calls) > 1: - warnings.warn( - "multiple tool calls found in a single delta, using the first, ignoring the rest", - stacklevel=2, - ) - - if not enable_incremental_tool_calls: - for tool_call in choice.delta.tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=ToolCallDelta( - tool_call=_convert_openai_tool_calls([tool_call])[0], # type: ignore[arg-type, list-item] # delta tool_call type differs from complete tool_call - parse_status=ToolCallParseStatus.succeeded, - ), - logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result 
- ) - ) - else: - for tool_call in choice.delta.tool_calls: - idx = tool_call.index if hasattr(tool_call, "index") else 0 - - if idx not in tool_call_idx_to_buffer: - tool_call_idx_to_buffer[idx] = { - "call_id": tool_call.id, - "name": None, - "arguments": "", - "content": "", - } - - buffer = tool_call_idx_to_buffer[idx] - - if tool_call.function: - if tool_call.function.name: - buffer["name"] = tool_call.function.name - delta = f"{buffer['name']}(" - if buffer["content"] is not None: - buffer["content"] += delta - - if tool_call.function.arguments: - delta = tool_call.function.arguments - if buffer["arguments"] is not None and delta: - buffer["arguments"] += delta - if buffer["content"] is not None and delta: - buffer["content"] += delta - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=ToolCallDelta( - tool_call=delta, - parse_status=ToolCallParseStatus.in_progress, - ), - logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result - ) - ) - elif choice.delta.content: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=TextDelta(text=choice.delta.content or ""), - logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result - ) - ) - - for idx, buffer in tool_call_idx_to_buffer.items(): - logger.debug(f"toolcall_buffer[{idx}]: {buffer}") - if buffer["name"]: - delta = ")" - if buffer["content"] is not None: - buffer["content"] += delta - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=ToolCallDelta( - tool_call=delta, - parse_status=ToolCallParseStatus.in_progress, - ), - logprobs=None, - ) - ) - - try: - parsed_tool_call = ToolCall( - call_id=buffer["call_id"] or "", - tool_name=buffer["name"] or "", - arguments=buffer["arguments"] or "", - ) - yield 
ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call=parsed_tool_call, # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall] - parse_status=ToolCallParseStatus.succeeded, - ), - stop_reason=stop_reason, - ) - ) - except json.JSONDecodeError as e: - print(f"Failed to parse arguments: {e}") - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call=buffer["content"], # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall] - parse_status=ToolCallParseStatus.failed, - ), - stop_reason=stop_reason, - ) - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta=TextDelta(text=""), - stop_reason=stop_reason, - ) - ) - - async def prepare_openai_completion_params(**params): async def _prepare_value(value: Any) -> Any: new_value = value @@ -1233,163 +421,6 @@ async def prepare_openai_completion_params(**params): return completion_params -class OpenAIChatCompletionToLlamaStackMixin: - async def openai_chat_completion( - self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - 
tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, - ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - messages = openai_messages_to_messages(messages) # type: ignore[assignment] # converted from OpenAI to LlamaStack message format - response_format = _convert_openai_request_response_format(response_format) - sampling_params = _convert_openai_sampling_params( - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - ) - tool_config = _convert_openai_request_tool_config(tool_choice) - - tools = _convert_openai_request_tools(tools) # type: ignore[assignment] # converted from OpenAI to LlamaStack tool format - if tool_config.tool_choice == ToolChoice.none: - tools = [] # type: ignore[assignment] # empty list narrows return type but mypy tracks broader type - - outstanding_responses = [] - # "n" is the number of completions to generate per prompt - n = n or 1 - for _i in range(0, n): - response = self.chat_completion( # type: ignore[attr-defined] # mixin expects class to implement chat_completion - model_id=model, - messages=messages, - sampling_params=sampling_params, - response_format=response_format, - stream=stream, - tool_config=tool_config, - tools=tools, - ) - outstanding_responses.append(response) - - if stream: - return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses) # type: ignore[no-any-return] # mixin async generator return type too complex for mypy - - return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response( - self, model, outstanding_responses - ) - - async def _process_stream_response( - self, - model: str, - outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]], - ): - id = f"chatcmpl-{uuid.uuid4()}" - for i, outstanding_response in enumerate(outstanding_responses): - response = 
await outstanding_response - async for chunk in response: - event = chunk.event - finish_reason = ( - _convert_stop_reason_to_openai_finish_reason(event.stop_reason) if event.stop_reason else None - ) - - if isinstance(event.delta, TextDelta): - text_delta = event.delta.text - delta = OpenAIChoiceDelta(content=text_delta) - yield OpenAIChatCompletionChunk( - id=id, - choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)], # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union - created=int(time.time()), - model=model, - object="chat.completion.chunk", - ) - elif isinstance(event.delta, ToolCallDelta): - if event.delta.parse_status == ToolCallParseStatus.succeeded: - tool_call = event.delta.tool_call - if isinstance(tool_call, str): - continue - - # First chunk includes full structure - openai_tool_call = OpenAIChoiceDeltaToolCall( - index=0, - id=tool_call.call_id, - function=OpenAIChoiceDeltaToolCallFunction( - name=tool_call.tool_name - if isinstance(tool_call.tool_name, str) - else tool_call.tool_name.value, # type: ignore[arg-type] # enum .value extraction on Union confuses mypy - arguments="", - ), - ) - delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call]) - yield OpenAIChatCompletionChunk( - id=id, - choices=[ - OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union - ], - created=int(time.time()), - model=model, - object="chat.completion.chunk", - ) - # arguments - openai_tool_call = OpenAIChoiceDeltaToolCall( - index=0, - function=OpenAIChoiceDeltaToolCallFunction( - arguments=tool_call.arguments, - ), - ) - delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call]) - yield OpenAIChatCompletionChunk( - id=id, - choices=[ - OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) # type: ignore[arg-type] # finish_reason Optional[str] incompatible 
with Literal union - ], - created=int(time.time()), - model=model, - object="chat.completion.chunk", - ) - - async def _process_non_stream_response( - self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]] - ) -> OpenAIChatCompletion: - choices: list[OpenAIChatCompletionChoice] = [] - for outstanding_response in outstanding_responses: - response = await outstanding_response - completion_message = response.completion_message - message = await convert_message_to_openai_dict_new(completion_message) - finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason) - - choice = OpenAIChatCompletionChoice( - index=len(choices), - message=message, # type: ignore[arg-type] # OpenAIChatCompletionMessage union incompatible with narrower Message type - finish_reason=finish_reason, - ) - choices.append(choice) # type: ignore[arg-type] # OpenAIChatCompletionChoice type annotation mismatch - - return OpenAIChatCompletion( - id=f"chatcmpl-{uuid.uuid4()}", - choices=choices, # type: ignore[arg-type] # list[OpenAIChatCompletionChoice] union incompatible - created=int(time.time()), - model=model, - object="chat.completion", - ) - - def prepare_openai_embeddings_params( model: str, input: str | list[str], diff --git a/src/llama_stack/providers/utils/inference/prompt_adapter.py b/src/llama_stack/providers/utils/inference/prompt_adapter.py index d06b7454d..35a7b3484 100644 --- a/src/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/src/llama_stack/providers/utils/inference/prompt_adapter.py @@ -21,19 +21,18 @@ from llama_stack.apis.common.content_types import ( TextContentItem, ) from llama_stack.apis.inference import ( - ChatCompletionRequest, CompletionRequest, - Message, + OpenAIAssistantMessageParam, OpenAIChatCompletionContentPartImageParam, OpenAIChatCompletionContentPartTextParam, OpenAIFile, + OpenAIMessageParam, + OpenAISystemMessageParam, + OpenAIToolMessageParam, + OpenAIUserMessageParam, ResponseFormat, 
ResponseFormatType, - SystemMessage, - SystemMessageBehavior, ToolChoice, - ToolDefinition, - UserMessage, ) from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( @@ -42,33 +41,19 @@ from llama_stack.models.llama.datatypes import ( RawMediaItem, RawMessage, RawTextItem, - Role, StopReason, + ToolCall, + ToolDefinition, ToolPromptFormat, ) from llama_stack.models.llama.llama3.chat_format import ChatFormat -from llama_stack.models.llama.llama3.prompt_templates import ( - BuiltinToolGenerator, - FunctionTagCustomToolGenerator, - JsonCustomToolGenerator, - PythonListCustomToolGenerator, - SystemDefaultGenerator, -) from llama_stack.models.llama.llama3.tokenizer import Tokenizer -from llama_stack.models.llama.llama4.prompt_templates.system_prompts import ( - PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4, -) from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal -from llama_stack.providers.utils.inference import supported_inference_models log = get_logger(name=__name__, category="providers::utils") -class ChatCompletionRequestWithRawContent(ChatCompletionRequest): - messages: list[RawMessage] - - class CompletionRequestWithRawContent(CompletionRequest): content: RawContent @@ -103,28 +88,6 @@ def interleaved_content_as_str( return _process(content) -async def convert_request_to_raw( - request: ChatCompletionRequest | CompletionRequest, -) -> ChatCompletionRequestWithRawContent | CompletionRequestWithRawContent: - if isinstance(request, ChatCompletionRequest): - messages = [] - for m in request.messages: - content = await interleaved_content_convert_to_raw(m.content) - d = m.model_dump() - d["content"] = content - messages.append(RawMessage(**d)) - - d = request.model_dump() - d["messages"] = messages - request = ChatCompletionRequestWithRawContent(**d) - else: - d = request.model_dump() - d["content"] = await 
interleaved_content_convert_to_raw(request.content) - request = CompletionRequestWithRawContent(**d) - - return request - - async def interleaved_content_convert_to_raw( content: InterleavedContent, ) -> RawContent: @@ -171,6 +134,36 @@ async def interleaved_content_convert_to_raw( return await _localize_single(content) +async def convert_openai_message_to_raw_message(message: OpenAIMessageParam) -> RawMessage: + """Convert OpenAI message format to RawMessage format used by Llama formatters.""" + if isinstance(message, OpenAIUserMessageParam): + content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type] + return RawMessage(role="user", content=content) + elif isinstance(message, OpenAISystemMessageParam): + content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type] + return RawMessage(role="system", content=content) + elif isinstance(message, OpenAIAssistantMessageParam): + content = await interleaved_content_convert_to_raw(message.content or "") # type: ignore[arg-type] + tool_calls = [] + if message.tool_calls: + for tc in message.tool_calls: + if tc.function: + tool_calls.append( + ToolCall( + call_id=tc.id or "", + tool_name=tc.function.name or "", + arguments=tc.function.arguments or "{}", + ) + ) + return RawMessage(role="assistant", content=content, tool_calls=tool_calls) + elif isinstance(message, OpenAIToolMessageParam): + content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type] + return RawMessage(role="tool", content=content) + else: + # Handle OpenAIDeveloperMessageParam if needed + raise ValueError(f"Unsupported message type: {type(message)}") + + def content_has_media(content: InterleavedContent): def _has_media_content(c): return isinstance(c, ImageContentItem) @@ -181,17 +174,6 @@ def content_has_media(content: InterleavedContent): return _has_media_content(content) -def messages_have_media(messages: list[Message]): - return 
any(content_has_media(m.content) for m in messages) - - -def request_has_media(request: ChatCompletionRequest | CompletionRequest): - if isinstance(request, ChatCompletionRequest): - return messages_have_media(request.messages) - else: - return content_has_media(request.content) - - async def localize_image_content(uri: str) -> tuple[bytes, str] | None: if uri.startswith("http"): async with httpx.AsyncClient() as client: @@ -253,79 +235,6 @@ def augment_content_with_response_format_prompt(response_format, content): return content -async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str: - messages = chat_completion_request_to_messages(request, llama_model) - request.messages = messages - request = await convert_request_to_raw(request) - - formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) - model_input = formatter.encode_dialog_prompt( - request.messages, - tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model), - ) - return formatter.tokenizer.decode(model_input.tokens) - - -async def chat_completion_request_to_model_input_info( - request: ChatCompletionRequest, llama_model: str -) -> tuple[str, int]: - messages = chat_completion_request_to_messages(request, llama_model) - request.messages = messages - request = await convert_request_to_raw(request) - - formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) - model_input = formatter.encode_dialog_prompt( - request.messages, - tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model), - ) - return ( - formatter.tokenizer.decode(model_input.tokens), - len(model_input.tokens), - ) - - -def chat_completion_request_to_messages( - request: ChatCompletionRequest, - llama_model: str, -) -> list[Message]: - """Reads chat completion request and augments the messages to handle tools. - For eg. 
for llama_3_1, add system message with the appropriate tools or - add user messsage for custom tools, etc. - """ - assert llama_model is not None, "llama_model is required" - model = resolve_model(llama_model) - if model is None: - log.error(f"Could not resolve model {llama_model}") - return request.messages - - allowed_models = supported_inference_models() - descriptors = [m.descriptor() for m in allowed_models] - if model.descriptor() not in descriptors: - log.error(f"Unsupported inference model? {model.descriptor()}") - return request.messages - - if model.model_family == ModelFamily.llama3_1 or ( - model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id) - ): - # llama3.1 and llama3.2 multimodal models follow the same tool prompt format - messages = augment_messages_for_tools_llama_3_1(request) - elif model.model_family in ( - ModelFamily.llama3_2, - ModelFamily.llama3_3, - ): - # llama3.2, llama3.3 follow the same tool prompt format - messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator) - elif model.model_family == ModelFamily.llama4: - messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4) - else: - messages = request.messages - - if fmt_prompt := response_format_prompt(request.response_format): - messages.append(UserMessage(content=fmt_prompt)) - - return messages - - def response_format_prompt(fmt: ResponseFormat | None): if not fmt: return None @@ -338,128 +247,6 @@ def response_format_prompt(fmt: ResponseFormat | None): raise ValueError(f"Unknown response format {fmt.type}") -def augment_messages_for_tools_llama_3_1( - request: ChatCompletionRequest, -) -> list[Message]: - existing_messages = request.messages - existing_system_message = None - if existing_messages[0].role == Role.system.value: - existing_system_message = existing_messages.pop(0) - - assert existing_messages[0].role != Role.system.value, "Should only have 1 system message" - - messages = [] - - 
default_gen = SystemDefaultGenerator() - default_template = default_gen.gen() - - sys_content = "" - - tool_template = None - if request.tools: - tool_gen = BuiltinToolGenerator() - tool_template = tool_gen.gen(request.tools) - - sys_content += tool_template.render() - sys_content += "\n" - - sys_content += default_template.render() - - if existing_system_message: - # TODO: this fn is needed in many places - def _process(c): - if isinstance(c, str): - return c - else: - return "" - - sys_content += "\n" - - if isinstance(existing_system_message.content, str): - sys_content += _process(existing_system_message.content) - elif isinstance(existing_system_message.content, list): - sys_content += "\n".join([_process(c) for c in existing_system_message.content]) - - tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools) - if tool_choice_prompt: - sys_content += "\n" + tool_choice_prompt - - messages.append(SystemMessage(content=sys_content)) - - has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools) - if has_custom_tools: - fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json - if fmt == ToolPromptFormat.json: - tool_gen = JsonCustomToolGenerator() - elif fmt == ToolPromptFormat.function_tag: - tool_gen = FunctionTagCustomToolGenerator() - else: - raise ValueError(f"Non supported ToolPromptFormat {fmt}") - - custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)] - custom_template = tool_gen.gen(custom_tools) - messages.append(UserMessage(content=custom_template.render())) - - # Add back existing messages from the request - messages += existing_messages - - return messages - - -def augment_messages_for_tools_llama( - request: ChatCompletionRequest, - custom_tool_prompt_generator, -) -> list[Message]: - existing_messages = request.messages - existing_system_message = None - if existing_messages[0].role == Role.system.value: - 
existing_system_message = existing_messages.pop(0) - - assert existing_messages[0].role != Role.system.value, "Should only have 1 system message" - - sys_content = "" - custom_tools, builtin_tools = [], [] - for t in request.tools: - if isinstance(t.tool_name, str): - custom_tools.append(t) - else: - builtin_tools.append(t) - - if builtin_tools: - tool_gen = BuiltinToolGenerator() - tool_template = tool_gen.gen(builtin_tools) - - sys_content += tool_template.render() - sys_content += "\n" - - custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)] - if custom_tools: - fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list - if fmt != ToolPromptFormat.python_list: - raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}") - - system_prompt = None - if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace: - system_prompt = existing_system_message.content - - tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt) - - sys_content += tool_template.render() - sys_content += "\n" - - if existing_system_message and ( - request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools - ): - sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n") - - tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools) - if tool_choice_prompt: - sys_content += "\n" + tool_choice_prompt - - messages = [SystemMessage(content=sys_content.strip("\n")), *existing_messages] - return messages - - def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str: if tool_choice == ToolChoice.auto: return "" diff --git a/tests/unit/models/test_prompt_adapter.py b/tests/unit/models/test_prompt_adapter.py deleted file mode 100644 index d31426135..000000000 --- a/tests/unit/models/test_prompt_adapter.py +++ /dev/null @@ 
-1,303 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - - -from llama_stack.apis.inference import ( - ChatCompletionRequest, - CompletionMessage, - StopReason, - SystemMessage, - SystemMessageBehavior, - ToolCall, - ToolConfig, - UserMessage, -) -from llama_stack.models.llama.datatypes import ( - BuiltinTool, - ToolDefinition, - ToolPromptFormat, -) -from llama_stack.providers.utils.inference.prompt_adapter import ( - chat_completion_request_to_messages, - chat_completion_request_to_prompt, - interleaved_content_as_str, -) - -MODEL = "Llama3.1-8B-Instruct" -MODEL3_2 = "Llama3.2-3B-Instruct" - - -async def test_system_default(): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 2 - assert messages[-1].content == content - assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content) - - -async def test_system_builtin_only(): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition(tool_name=BuiltinTool.brave_search), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 2 - assert messages[-1].content == content - assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content) - assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content) - - -async def test_system_custom_only(): - content = "Hello !" 
- request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ) - ], - tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json), - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 3 - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - - assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content) - assert messages[-1].content == content - - -async def test_system_custom_and_builtin(): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition(tool_name=BuiltinTool.brave_search), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 3 - - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content) - - assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content) - assert messages[-1].content == content - - -async def test_completion_message_encoding(): - request = ChatCompletionRequest( - model=MODEL3_2, - messages=[ - UserMessage(content="hello"), - CompletionMessage( - content="", - stop_reason=StopReason.end_of_turn, - tool_calls=[ - ToolCall( - tool_name="custom1", - arguments='{"param1": "value1"}', 
# arguments must be a JSON string - call_id="123", - ) - ], - ), - ], - tools=[ - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ), - ], - tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list), - ) - prompt = await chat_completion_request_to_prompt(request, request.model) - assert '[custom1(param1="value1")]' in prompt - - request.model = MODEL - request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json) - prompt = await chat_completion_request_to_prompt(request, request.model) - assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt - - -async def test_user_provided_system_message(): - content = "Hello !" - system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 2 - assert interleaved_content_as_str(messages[0].content).endswith(system_prompt) - - assert messages[-1].content == content - - -async def test_replace_system_message_behavior_builtin_tools(): - content = "Hello !" 
- system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format=ToolPromptFormat.python_list, - system_message_behavior=SystemMessageBehavior.replace, - ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) - assert len(messages) == 2 - assert interleaved_content_as_str(messages[0].content).endswith(system_prompt) - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - assert messages[-1].content == content - - -async def test_replace_system_message_behavior_custom_tools(): - content = "Hello !" - system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format=ToolPromptFormat.python_list, - system_message_behavior=SystemMessageBehavior.replace, - ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) - - assert len(messages) == 2 - assert interleaved_content_as_str(messages[0].content).endswith(system_prompt) - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - assert messages[-1].content == content - - -async def test_replace_system_message_behavior_custom_tools_with_template(): - content = "Hello !" 
- system_prompt = "You are a pirate {{ function_description }}" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format=ToolPromptFormat.python_list, - system_message_behavior=SystemMessageBehavior.replace, - ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) - - assert len(messages) == 2 - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - assert "You are a pirate" in interleaved_content_as_str(messages[0].content) - # function description is present in the system prompt - assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content) - assert messages[-1].content == content diff --git a/tests/unit/providers/inline/inference/__init__.py b/tests/unit/providers/inline/inference/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/unit/providers/inline/inference/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/unit/providers/inline/inference/test_meta_reference.py b/tests/unit/providers/inline/inference/test_meta_reference.py new file mode 100644 index 000000000..381836397 --- /dev/null +++ b/tests/unit/providers/inline/inference/test_meta_reference.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import Mock + +import pytest + +from llama_stack.providers.inline.inference.meta_reference.model_parallel import ( + ModelRunner, +) + + +class TestModelRunner: + """Test ModelRunner task dispatching for model-parallel inference.""" + + def test_chat_completion_task_dispatch(self): + """Verify ModelRunner correctly dispatches chat_completion tasks.""" + # Create a mock generator + mock_generator = Mock() + mock_generator.chat_completion = Mock(return_value=iter([])) + + runner = ModelRunner(mock_generator) + + # Create a chat_completion task + fake_params = {"model": "test"} + fake_messages = [{"role": "user", "content": "test"}] + task = ("chat_completion", [fake_params, fake_messages]) + + # Execute task + runner(task) + + # Verify chat_completion was called with correct arguments + mock_generator.chat_completion.assert_called_once_with(fake_params, fake_messages) + + def test_invalid_task_type_raises_error(self): + """Verify ModelRunner rejects invalid task types.""" + mock_generator = Mock() + runner = ModelRunner(mock_generator) + + with pytest.raises(ValueError, match="Unexpected task type"): + runner(("invalid_task", [])) diff --git a/tests/unit/providers/nvidia/test_safety.py b/tests/unit/providers/nvidia/test_safety.py index 922d7f61f..622302630 100644 --- a/tests/unit/providers/nvidia/test_safety.py +++ b/tests/unit/providers/nvidia/test_safety.py @@ -10,11 +10,13 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from llama_stack.apis.inference import CompletionMessage, UserMessage +from llama_stack.apis.inference import ( + OpenAIAssistantMessageParam, + OpenAIUserMessageParam, +) from llama_stack.apis.resource import ResourceType from llama_stack.apis.safety import RunShieldResponse, ViolationLevel from llama_stack.apis.shields import Shield -from llama_stack.models.llama.datatypes 
import StopReason from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter @@ -136,11 +138,9 @@ async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post): # Run the shield messages = [ - UserMessage(role="user", content="Hello, how are you?"), - CompletionMessage( - role="assistant", + OpenAIUserMessageParam(content="Hello, how are you?"), + OpenAIAssistantMessageParam( content="I'm doing well, thank you for asking!", - stop_reason=StopReason.end_of_message, tool_calls=[], ), ] @@ -191,13 +191,10 @@ async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post): # Mock Guardrails API response mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}} - # Run the shield messages = [ - UserMessage(role="user", content="Hello, how are you?"), - CompletionMessage( - role="assistant", + OpenAIUserMessageParam(content="Hello, how are you?"), + OpenAIAssistantMessageParam( content="I'm doing well, thank you for asking!", - stop_reason=StopReason.end_of_message, tool_calls=[], ), ] @@ -243,7 +240,7 @@ async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post): adapter.shield_store.get_shield.return_value = None messages = [ - UserMessage(role="user", content="Hello, how are you?"), + OpenAIUserMessageParam(content="Hello, how are you?"), ] with pytest.raises(ValueError): @@ -274,11 +271,9 @@ async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post): # Running the shield should raise an exception messages = [ - UserMessage(role="user", content="Hello, how are you?"), - CompletionMessage( - role="assistant", + OpenAIUserMessageParam(content="Hello, how are you?"), + OpenAIAssistantMessageParam( content="I'm doing well, thank you for asking!", - stop_reason=StopReason.end_of_message, tool_calls=[], ), ] diff --git 
a/tests/unit/providers/utils/inference/test_openai_compat.py b/tests/unit/providers/utils/inference/test_openai_compat.py deleted file mode 100644 index c200c4395..000000000 --- a/tests/unit/providers/utils/inference/test_openai_compat.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import pytest -from pydantic import ValidationError - -from llama_stack.apis.common.content_types import TextContentItem -from llama_stack.apis.inference import ( - CompletionMessage, - OpenAIAssistantMessageParam, - OpenAIChatCompletionContentPartImageParam, - OpenAIChatCompletionContentPartTextParam, - OpenAIDeveloperMessageParam, - OpenAIImageURL, - OpenAISystemMessageParam, - OpenAIToolMessageParam, - OpenAIUserMessageParam, - SystemMessage, - UserMessage, -) -from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall -from llama_stack.providers.utils.inference.openai_compat import ( - convert_message_to_openai_dict, - convert_message_to_openai_dict_new, - openai_messages_to_messages, -) - - -async def test_convert_message_to_openai_dict(): - message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user") - assert await convert_message_to_openai_dict(message) == { - "role": "user", - "content": [{"type": "text", "text": "Hello, world!"}], - } - - -# Test convert_message_to_openai_dict with a tool call -async def test_convert_message_to_openai_dict_with_tool_call(): - message = CompletionMessage( - content="", - tool_calls=[ToolCall(call_id="123", tool_name="test_tool", arguments='{"foo": "bar"}')], - stop_reason=StopReason.end_of_turn, - ) - - openai_dict = await convert_message_to_openai_dict(message) - - assert openai_dict == { - "role": "assistant", - "content": [{"type": "text", "text": ""}], - "tool_calls": [ - {"id": "123", "type": "function", 
"function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}} - ], - } - - -async def test_convert_message_to_openai_dict_with_builtin_tool_call(): - message = CompletionMessage( - content="", - tool_calls=[ - ToolCall( - call_id="123", - tool_name=BuiltinTool.brave_search, - arguments='{"foo": "bar"}', - ) - ], - stop_reason=StopReason.end_of_turn, - ) - - openai_dict = await convert_message_to_openai_dict(message) - - assert openai_dict == { - "role": "assistant", - "content": [{"type": "text", "text": ""}], - "tool_calls": [ - {"id": "123", "type": "function", "function": {"name": "brave_search", "arguments": '{"foo": "bar"}'}} - ], - } - - -async def test_openai_messages_to_messages_with_content_str(): - openai_messages = [ - OpenAISystemMessageParam(content="system message"), - OpenAIUserMessageParam(content="user message"), - OpenAIAssistantMessageParam(content="assistant message"), - ] - - llama_messages = openai_messages_to_messages(openai_messages) - assert len(llama_messages) == 3 - assert isinstance(llama_messages[0], SystemMessage) - assert isinstance(llama_messages[1], UserMessage) - assert isinstance(llama_messages[2], CompletionMessage) - assert llama_messages[0].content == "system message" - assert llama_messages[1].content == "user message" - assert llama_messages[2].content == "assistant message" - - -async def test_openai_messages_to_messages_with_content_list(): - openai_messages = [ - OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]), - OpenAIUserMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="user message")]), - OpenAIAssistantMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="assistant message")]), - ] - - llama_messages = openai_messages_to_messages(openai_messages) - assert len(llama_messages) == 3 - assert isinstance(llama_messages[0], SystemMessage) - assert isinstance(llama_messages[1], UserMessage) - assert isinstance(llama_messages[2], 
CompletionMessage) - assert llama_messages[0].content[0].text == "system message" - assert llama_messages[1].content[0].text == "user message" - assert llama_messages[2].content[0].text == "assistant message" - - -@pytest.mark.parametrize( - "message_class,kwargs", - [ - (OpenAISystemMessageParam, {}), - (OpenAIAssistantMessageParam, {}), - (OpenAIDeveloperMessageParam, {}), - (OpenAIUserMessageParam, {}), - (OpenAIToolMessageParam, {"tool_call_id": "call_123"}), - ], -) -def test_message_accepts_text_string(message_class, kwargs): - """Test that messages accept string text content.""" - msg = message_class(content="Test message", **kwargs) - assert msg.content == "Test message" - - -@pytest.mark.parametrize( - "message_class,kwargs", - [ - (OpenAISystemMessageParam, {}), - (OpenAIAssistantMessageParam, {}), - (OpenAIDeveloperMessageParam, {}), - (OpenAIUserMessageParam, {}), - (OpenAIToolMessageParam, {"tool_call_id": "call_123"}), - ], -) -def test_message_accepts_text_list(message_class, kwargs): - """Test that messages accept list of text content parts.""" - content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")] - msg = message_class(content=content_list, **kwargs) - assert len(msg.content) == 1 - assert msg.content[0].text == "Test message" - - -@pytest.mark.parametrize( - "message_class,kwargs", - [ - (OpenAISystemMessageParam, {}), - (OpenAIAssistantMessageParam, {}), - (OpenAIDeveloperMessageParam, {}), - (OpenAIToolMessageParam, {"tool_call_id": "call_123"}), - ], -) -def test_message_rejects_images(message_class, kwargs): - """Test that system, assistant, developer, and tool messages reject image content.""" - with pytest.raises(ValidationError): - message_class( - content=[ - OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")) - ], - **kwargs, - ) - - -def test_user_message_accepts_images(): - """Test that user messages accept image content (unlike other message types).""" - # 
List with images should work - msg = OpenAIUserMessageParam( - content=[ - OpenAIChatCompletionContentPartTextParam(text="Describe this image:"), - OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")), - ] - ) - assert len(msg.content) == 2 - assert msg.content[0].text == "Describe this image:" - assert msg.content[1].image_url.url == "http://example.com/image.jpg" - - -async def test_convert_message_to_openai_dict_new_user_message(): - """Test convert_message_to_openai_dict_new with UserMessage.""" - message = UserMessage(content="Hello, world!", role="user") - result = await convert_message_to_openai_dict_new(message) - - assert result["role"] == "user" - assert result["content"] == "Hello, world!" - - -async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls(): - """Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls.""" - message = CompletionMessage( - content="I'll help you find the weather.", - tool_calls=[ - ToolCall( - call_id="call_123", - tool_name="get_weather", - arguments='{"city": "Sligo"}', - ) - ], - stop_reason=StopReason.end_of_turn, - ) - result = await convert_message_to_openai_dict_new(message) - - # This would have failed with "Cannot instantiate typing.Union" before the fix - assert result["role"] == "assistant" - assert result["content"] == "I'll help you find the weather." 
- assert "tool_calls" in result - assert result["tool_calls"] is not None - assert len(result["tool_calls"]) == 1 - - tool_call = result["tool_calls"][0] - assert tool_call.id == "call_123" - assert tool_call.type == "function" - assert tool_call.function.name == "get_weather" - assert tool_call.function.arguments == '{"city": "Sligo"}' diff --git a/tests/unit/providers/utils/inference/test_prompt_adapter.py b/tests/unit/providers/utils/inference/test_prompt_adapter.py new file mode 100644 index 000000000..62c8db74d --- /dev/null +++ b/tests/unit/providers/utils/inference/test_prompt_adapter.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.inference import ( + OpenAIAssistantMessageParam, + OpenAIUserMessageParam, +) +from llama_stack.models.llama.datatypes import RawTextItem +from llama_stack.providers.utils.inference.prompt_adapter import ( + convert_openai_message_to_raw_message, +) + + +class TestConvertOpenAIMessageToRawMessage: + """Test conversion of OpenAI message types to RawMessage format.""" + + async def test_user_message_conversion(self): + msg = OpenAIUserMessageParam(role="user", content="Hello world") + raw_msg = await convert_openai_message_to_raw_message(msg) + + assert raw_msg.role == "user" + assert isinstance(raw_msg.content, RawTextItem) + assert raw_msg.content.text == "Hello world" + + async def test_assistant_message_conversion(self): + msg = OpenAIAssistantMessageParam(role="assistant", content="Hi there!") + raw_msg = await convert_openai_message_to_raw_message(msg) + + assert raw_msg.role == "assistant" + assert isinstance(raw_msg.content, RawTextItem) + assert raw_msg.content.text == "Hi there!" 
+ assert raw_msg.tool_calls == [] From 97ccfb5e626919956aee0bf0e890a7196e8af6a2 Mon Sep 17 00:00:00 2001 From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com> Date: Mon, 10 Nov 2025 18:57:17 -0500 Subject: [PATCH 06/15] refactor: inspect routes now shows all non-deprecated APIs (#4116) # What does this PR do? the inspect API lacked any mechanism to get all non-deprecated APIs (v1, v1alpha, v1beta) change default to this behavior 'v1' filter can be used for user' wanting a list of stable APIs ## Test Plan 1. pull the PR 2. launch a LLS server 3. run `curl http://beanlab3.bss.redhat.com:8321/v1/inspect/routes` 4. note there are APIs for `v1`, `v1alpha`, and `v1beta` but no deprecated APIs Signed-off-by: Nathan Weinberg --- client-sdks/stainless/openapi.yml | 2 +- docs/static/llama-stack-spec.yaml | 2 +- docs/static/stainless-llama-stack-spec.yaml | 2 +- src/llama_stack/apis/inspect/inspect.py | 2 +- src/llama_stack/core/inspect.py | 5 ++--- 5 files changed, 6 insertions(+), 7 deletions(-) diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml index 58ebaa8ae..9f3ef15b5 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -963,7 +963,7 @@ paths: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, - returns only non-deprecated v1 routes. + returns all non-deprecated routes. required: false schema: type: string diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index 135ae910f..ce8708b68 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -960,7 +960,7 @@ paths: Optional filter to control which routes are returned. 
Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, - returns only non-deprecated v1 routes. + returns all non-deprecated routes. required: false schema: type: string diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index 58ebaa8ae..9f3ef15b5 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -963,7 +963,7 @@ paths: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, - returns only non-deprecated v1 routes. + returns all non-deprecated routes. required: false schema: type: string diff --git a/src/llama_stack/apis/inspect/inspect.py b/src/llama_stack/apis/inspect/inspect.py index 4e0e2548b..235abb124 100644 --- a/src/llama_stack/apis/inspect/inspect.py +++ b/src/llama_stack/apis/inspect/inspect.py @@ -76,7 +76,7 @@ class Inspect(Protocol): List all available API routes with their methods and implementing providers. - :param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes. + :param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns all non-deprecated routes. :returns: Response containing information about all available routes. """ ... 
diff --git a/src/llama_stack/core/inspect.py b/src/llama_stack/core/inspect.py index 6352af00f..07b51128f 100644 --- a/src/llama_stack/core/inspect.py +++ b/src/llama_stack/core/inspect.py @@ -15,7 +15,6 @@ from llama_stack.apis.inspect import ( RouteInfo, VersionInfo, ) -from llama_stack.apis.version import LLAMA_STACK_API_V1 from llama_stack.core.datatypes import StackRunConfig from llama_stack.core.external import load_external_apis from llama_stack.core.server.routes import get_all_api_routes @@ -46,8 +45,8 @@ class DistributionInspectImpl(Inspect): # Helper function to determine if a route should be included based on api_filter def should_include_route(webmethod) -> bool: if api_filter is None: - # Default: only non-deprecated v1 APIs - return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1 + # Default: only non-deprecated APIs + return not webmethod.deprecated elif api_filter == "deprecated": # Special filter: show deprecated routes regardless of their actual level return bool(webmethod.deprecated) From e5a55f36776575f89145be1004935e8113aed90a Mon Sep 17 00:00:00 2001 From: paulengineer <154521137+paulengineer@users.noreply.github.com> Date: Tue, 11 Nov 2025 12:49:03 +0000 Subject: [PATCH 07/15] docs: use 'uv pip' to avoid pitfalls of using 'pip' in virtual environment (#4122) # What does this PR do? In the **Detailed Tutorial**, at **Step 3**, the **Install with venv** option creates a new virtual environment `client`, activates it then attempts to install the llama-stack-client using pip. ``` uv venv client --python 3.12 source client/bin/activate pip install llama-stack-client <- this is the problematic line ``` However, the pip command will likely fail because the `uv venv` command doesn't, by default, include adding the pip command to the virtual environment that is created. 
The pip command will error either because pip doesn't exist at all, or, if the pip command does exist outside of the virtual environment, return a different error message. The latter may be unclear to the user why it is failing. This PR changes 'pip' to 'uv pip', allowing the install action to function in the virtual environment as intended, and without the need for pip to be installed. ## Test Plan 1. Use linux or WSL (virtual environments on Windows use `Scripts` folder instead of `bin` [virtualenv #993ba13](https://github.com/pypa/virtualenv/commit/993ba1316a83b760370f5a3872b3f5ef4dd904c1) which doesn't align with the tutorial) 2. Clone the `llama-stack` repo 3. Run the following and verify success: ``` uv venv client --python 3.12 source client/bin/activate ``` 5. Run the updated command: ``` uv pip install llama-stack-client ``` 6. Observe the console output confirms that the virtual environment `client` was used: > Using Python 3.12.3 environment at: **client** --- docs/docs/getting_started/detailed_tutorial.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/docs/getting_started/detailed_tutorial.mdx b/docs/docs/getting_started/detailed_tutorial.mdx index 623301d0d..2816f67a2 100644 --- a/docs/docs/getting_started/detailed_tutorial.mdx +++ b/docs/docs/getting_started/detailed_tutorial.mdx @@ -144,7 +144,7 @@ source .venv/bin/activate ```bash uv venv client --python 3.12 source client/bin/activate -pip install llama-stack-client +uv pip install llama-stack-client ``` From 71b328fc4bddd672a0665a15484ee64e5464c62a Mon Sep 17 00:00:00 2001 From: ehhuang Date: Tue, 11 Nov 2025 10:40:31 -0800 Subject: [PATCH 08/15] chore(ui): add npm package and dockerfile (#4100) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? 
- sets up package.json for npm `llama-stack-ui` package (will update llama-stack-ops) - adds dockerfile for UI docker image ## Test Plan npx: npm build && npm pack LLAMA_STACK_UI_PORT=8322 npx /Users/erichuang/projects/ui/src/llama_stack_ui/llama-stack-ui-0.4.0-alpha.2.tgz docker: cd src/llama_stack_ui docker build . -f Dockerfile --tag test_ui --no-cache ❯ docker run -p 8322:8322 \ -e LLAMA_STACK_UI_PORT=8322 \ test_ui:latest --- docs/docs/distributions/index.mdx | 1 + docs/docs/distributions/llama_stack_ui.mdx | 109 +++++++++++++++++++++ docs/sidebars.ts | 1 + src/llama_stack_ui/.dockerignore | 20 ++++ src/llama_stack_ui/Containerfile | 18 ++++ src/llama_stack_ui/bin/cli.js | 34 +++++++ src/llama_stack_ui/next.config.ts | 8 +- src/llama_stack_ui/package-lock.json | 16 +-- src/llama_stack_ui/package.json | 30 +++++- src/llama_stack_ui/scripts/postbuild.js | 40 ++++++++ 10 files changed, 264 insertions(+), 13 deletions(-) create mode 100644 docs/docs/distributions/llama_stack_ui.mdx create mode 100644 src/llama_stack_ui/.dockerignore create mode 100644 src/llama_stack_ui/Containerfile create mode 100755 src/llama_stack_ui/bin/cli.js create mode 100644 src/llama_stack_ui/scripts/postbuild.js diff --git a/docs/docs/distributions/index.mdx b/docs/docs/distributions/index.mdx index 0149f143f..ebf4bd6ce 100644 --- a/docs/docs/distributions/index.mdx +++ b/docs/docs/distributions/index.mdx @@ -19,3 +19,4 @@ This section provides an overview of the distributions available in Llama Stack. 
- **[Starting Llama Stack Server](./starting_llama_stack_server.mdx)** - How to run distributions - **[Importing as Library](./importing_as_library.mdx)** - Use distributions in your code - **[Configuration Reference](./configuration.mdx)** - Configuration file format details +- **[Llama Stack UI](./llama_stack_ui.mdx)** - Web-based user interface for interacting with Llama Stack servers diff --git a/docs/docs/distributions/llama_stack_ui.mdx b/docs/docs/distributions/llama_stack_ui.mdx new file mode 100644 index 000000000..7ba47ea4d --- /dev/null +++ b/docs/docs/distributions/llama_stack_ui.mdx @@ -0,0 +1,109 @@ +--- +title: Llama Stack UI +description: Web-based user interface for interacting with Llama Stack servers +sidebar_label: Llama Stack UI +sidebar_position: 8 +--- + +# Llama Stack UI + +The Llama Stack UI is a web-based interface for interacting with Llama Stack servers. Built with Next.js and React, it provides a visual way to work with agents, manage resources, and view logs. + +## Features + +- **Logs & Monitoring**: View chat completions, agent responses, and vector store activity +- **Vector Stores**: Create and manage vector databases for RAG (Retrieval-Augmented Generation) workflows +- **Prompt Management**: Create and manage reusable prompts + +## Prerequisites + +You need a running Llama Stack server. The UI is a client that connects to the Llama Stack backend. + +If you don't have a Llama Stack server running yet, see the [Starting Llama Stack Server](../getting_started/starting_llama_stack_server.mdx) guide. + +## Running the UI + +### Option 1: Using npx (Recommended for Quick Start) + +The fastest way to get started is using `npx`: + +```bash +npx llama-stack-ui +``` + +This will start the UI server on `http://localhost:8322` (default port). + +### Option 2: Using Docker + +Run the UI in a container: + +```bash +docker run -p 8322:8322 llamastack/ui +``` + +Access the UI at `http://localhost:8322`. 
+ +## Environment Variables + +The UI can be configured using the following environment variables: + +| Variable | Description | Default | +|----------|-------------|---------| +| `LLAMA_STACK_BACKEND_URL` | URL of your Llama Stack server | `http://localhost:8321` | +| `LLAMA_STACK_UI_PORT` | Port for the UI server | `8322` | + +If the Llama Stack server is running with authentication enabled, you can configure the UI to use it by setting the following environment variables: + +| Variable | Description | Default | +|----------|-------------|---------| +| `NEXTAUTH_URL` | NextAuth URL for authentication | `http://localhost:8322` | +| `GITHUB_CLIENT_ID` | GitHub OAuth client ID (optional, for authentication) | - | +| `GITHUB_CLIENT_SECRET` | GitHub OAuth client secret (optional, for authentication) | - | + +### Setting Environment Variables + +#### For npx: + +```bash +LLAMA_STACK_BACKEND_URL=http://localhost:8321 \ +LLAMA_STACK_UI_PORT=8080 \ +npx llama-stack-ui +``` + +#### For Docker: + +```bash +docker run -p 8080:8080 \ + -e LLAMA_STACK_BACKEND_URL=http://localhost:8321 \ + -e LLAMA_STACK_UI_PORT=8080 \ + llamastack/ui +``` + +## Using the UI + +### Managing Resources + +- **Vector Stores**: Create vector databases for RAG workflows, view stored documents and embeddings +- **Prompts**: Create and manage reusable prompt templates +- **Chat Completions**: View history of chat interactions +- **Responses**: Browse detailed agent responses and tool calls + +## Development + +If you want to run the UI from source for development: + +```bash +# From the project root +cd src/llama_stack_ui + +# Install dependencies +npm install + +# Set environment variables +export LLAMA_STACK_BACKEND_URL=http://localhost:8321 + +# Start the development server +npm run dev +``` + +The development server will start on `http://localhost:8322` with hot reloading enabled. 
diff --git a/docs/sidebars.ts b/docs/sidebars.ts index 641c2eed3..7b4ac5ac8 100644 --- a/docs/sidebars.ts +++ b/docs/sidebars.ts @@ -57,6 +57,7 @@ const sidebars: SidebarsConfig = { 'distributions/importing_as_library', 'distributions/configuration', 'distributions/starting_llama_stack_server', + 'distributions/llama_stack_ui', { type: 'category', label: 'Self-Hosted Distributions', diff --git a/src/llama_stack_ui/.dockerignore b/src/llama_stack_ui/.dockerignore new file mode 100644 index 000000000..e3d1daae6 --- /dev/null +++ b/src/llama_stack_ui/.dockerignore @@ -0,0 +1,20 @@ +.git +.gitignore +.env.local +.env.*.local +.next +node_modules +npm-debug.log +*.md +.DS_Store +.vscode +.idea +playwright-report +e2e +jest.config.ts +jest.setup.ts +eslint.config.mjs +.prettierrc +.prettierignore +.nvmrc +playwright.config.ts diff --git a/src/llama_stack_ui/Containerfile b/src/llama_stack_ui/Containerfile new file mode 100644 index 000000000..6aea3dbfd --- /dev/null +++ b/src/llama_stack_ui/Containerfile @@ -0,0 +1,18 @@ +FROM node:22.5.1-alpine + +ENV NODE_ENV=production + +# Install dumb-init for proper signal handling +RUN apk add --no-cache dumb-init + +# Create non-root user for security +RUN addgroup --system --gid 1001 nodejs +RUN adduser --system --uid 1001 nextjs + +# Install llama-stack-ui from npm +RUN npm install -g llama-stack-ui + +USER nextjs + +ENTRYPOINT ["dumb-init", "--"] +CMD ["llama-stack-ui"] diff --git a/src/llama_stack_ui/bin/cli.js b/src/llama_stack_ui/bin/cli.js new file mode 100755 index 000000000..6069d2f22 --- /dev/null +++ b/src/llama_stack_ui/bin/cli.js @@ -0,0 +1,34 @@ +#!/usr/bin/env node + +const { spawn } = require('child_process'); +const path = require('path'); + +const port = process.env.LLAMA_STACK_UI_PORT || 8322; +const uiDir = path.resolve(__dirname, '..'); +const serverPath = path.join(uiDir, '.next', 'standalone', 'ui', 'src', 'llama_stack_ui', 'server.js'); +const serverDir = path.dirname(serverPath); + +console.log(`Starting 
Llama Stack UI on http://localhost:${port}`); + +const child = spawn(process.execPath, [serverPath], { + cwd: serverDir, + stdio: 'inherit', + env: { + ...process.env, + PORT: port, + }, +}); + +process.on('SIGINT', () => { + child.kill('SIGINT'); + process.exit(0); +}); + +process.on('SIGTERM', () => { + child.kill('SIGTERM'); + process.exit(0); +}); + +child.on('exit', (code) => { + process.exit(code); +}); diff --git a/src/llama_stack_ui/next.config.ts b/src/llama_stack_ui/next.config.ts index e9ffa3083..9f4a74eca 100644 --- a/src/llama_stack_ui/next.config.ts +++ b/src/llama_stack_ui/next.config.ts @@ -1,7 +1,13 @@ import type { NextConfig } from "next"; const nextConfig: NextConfig = { - /* config options here */ + typescript: { + ignoreBuildErrors: true, + }, + output: "standalone", + images: { + unoptimized: true, + }, }; export default nextConfig; diff --git a/src/llama_stack_ui/package-lock.json b/src/llama_stack_ui/package-lock.json index 14e34b720..aa8b2ac26 100644 --- a/src/llama_stack_ui/package-lock.json +++ b/src/llama_stack_ui/package-lock.json @@ -1,12 +1,13 @@ { - "name": "ui", - "version": "0.1.0", + "name": "llama-stack-ui", + "version": "0.4.0-alpha.1", "lockfileVersion": 3, "requires": true, "packages": { "": { - "name": "ui", - "version": "0.1.0", + "name": "llama-stack-ui", + "version": "0.4.0-alpha.1", + "license": "MIT", "dependencies": { "@radix-ui/react-collapsible": "^1.1.12", "@radix-ui/react-dialog": "^1.1.15", @@ -20,7 +21,7 @@ "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", "framer-motion": "^12.23.24", - "llama-stack-client": "github:llamastack/llama-stack-client-typescript", + "llama-stack-client": "^0.3.1", "lucide-react": "^0.545.0", "next": "15.5.4", "next-auth": "^4.24.11", @@ -9684,8 +9685,9 @@ "license": "MIT" }, "node_modules/llama-stack-client": { - "version": "0.4.0-alpha.1", - "resolved": "git+ssh://git@github.com/llamastack/llama-stack-client-typescript.git#78de4862c4b7d77939ac210fa9f9bde77a2c5c5f", + "version": 
"0.3.1", + "resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.3.1.tgz", + "integrity": "sha512-4aYoF2aAQiBSfxyZEtczeQmJn8q9T22ePDqGhR+ej5RG6a8wvl5B3v7ZoKuFkft+vcP/kbJ58GQZEPLekxekZA==", "license": "MIT", "dependencies": { "@types/node": "^18.11.18", diff --git a/src/llama_stack_ui/package.json b/src/llama_stack_ui/package.json index fb7dbee75..41afc9a11 100644 --- a/src/llama_stack_ui/package.json +++ b/src/llama_stack_ui/package.json @@ -1,11 +1,31 @@ { - "name": "ui", - "version": "0.1.0", - "private": true, + "name": "llama-stack-ui", + "version": "0.4.0-alpha.4", + "description": "Web UI for Llama Stack", + "license": "MIT", + "author": "Llama Stack ", + "repository": { + "type": "git", + "url": "https://github.com/llamastack/llama-stack.git", + "directory": "llama_stack_ui" + }, + "bin": { + "llama-stack-ui": "bin/cli.js" + }, + "files": [ + "bin", + ".next", + "public", + "next.config.ts", + "instrumentation.ts", + "tsconfig.json", + "package.json" + ], "scripts": { "dev": "next dev --turbopack --port ${LLAMA_STACK_UI_PORT:-8322}", - "build": "next build", + "build": "next build && node scripts/postbuild.js", "start": "next start", + "prepublishOnly": "npm run build", "lint": "next lint", "format": "prettier --write \"./**/*.{ts,tsx}\"", "format:check": "prettier --check \"./**/*.{ts,tsx}\"", @@ -25,7 +45,7 @@ "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", "framer-motion": "^12.23.24", - "llama-stack-client": "github:llamastack/llama-stack-client-typescript", + "llama-stack-client": "^0.3.1", "lucide-react": "^0.545.0", "next": "15.5.4", "next-auth": "^4.24.11", diff --git a/src/llama_stack_ui/scripts/postbuild.js b/src/llama_stack_ui/scripts/postbuild.js new file mode 100644 index 000000000..4b4dbdf5d --- /dev/null +++ b/src/llama_stack_ui/scripts/postbuild.js @@ -0,0 +1,40 @@ +const fs = require('fs'); +const path = require('path'); + +// Copy public directory to standalone +const publicSrc = path.join(__dirname, 
'..', 'public'); +const publicDest = path.join(__dirname, '..', '.next', 'standalone', 'ui', 'src', 'llama_stack_ui', 'public'); + +if (fs.existsSync(publicSrc) && !fs.existsSync(publicDest)) { + console.log('Copying public directory to standalone...'); + copyDir(publicSrc, publicDest); +} + +// Copy .next/static to standalone +const staticSrc = path.join(__dirname, '..', '.next', 'static'); +const staticDest = path.join(__dirname, '..', '.next', 'standalone', 'ui', 'src', 'llama_stack_ui', '.next', 'static'); + +if (fs.existsSync(staticSrc) && !fs.existsSync(staticDest)) { + console.log('Copying .next/static to standalone...'); + copyDir(staticSrc, staticDest); +} + +function copyDir(src, dest) { + if (!fs.existsSync(dest)) { + fs.mkdirSync(dest, { recursive: true }); + } + + const files = fs.readdirSync(src); + files.forEach((file) => { + const srcFile = path.join(src, file); + const destFile = path.join(dest, file); + + if (fs.statSync(srcFile).isDirectory()) { + copyDir(srcFile, destFile); + } else { + fs.copyFileSync(srcFile, destFile); + } + }); +} + +console.log('Postbuild complete!'); From 6ca2a67a9f1bfec5c4e520b6d82407d4d8ecd914 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Wed, 12 Nov 2025 04:09:14 -0500 Subject: [PATCH 09/15] chore: remove dead code (#4125) # What does this PR do? build_image is not used because `llama stack build` is gone. Remove it. Signed-off-by: Charlie Doern --- src/llama_stack/core/build.py | 65 ----------------------------------- 1 file changed, 65 deletions(-) diff --git a/src/llama_stack/core/build.py b/src/llama_stack/core/build.py index 2ceb9e9be..fb3a22109 100644 --- a/src/llama_stack/core/build.py +++ b/src/llama_stack/core/build.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import importlib.resources import sys from pydantic import BaseModel @@ -12,9 +11,6 @@ from termcolor import cprint from llama_stack.core.datatypes import BuildConfig from llama_stack.core.distribution import get_provider_registry -from llama_stack.core.external import load_external_apis -from llama_stack.core.utils.exec import run_command -from llama_stack.core.utils.image_types import LlamaStackImageType from llama_stack.distributions.template import DistributionTemplate from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api @@ -101,64 +97,3 @@ def print_pip_install_help(config: BuildConfig): for special_dep in special_deps: cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr) print() - - -def build_image( - build_config: BuildConfig, - image_name: str, - distro_or_config: str, - run_config: str | None = None, -): - container_base = build_config.distribution_spec.container_image or "python:3.12-slim" - - normal_deps, special_deps, external_provider_deps = get_provider_dependencies(build_config) - normal_deps += SERVER_DEPENDENCIES - if build_config.external_apis_dir: - external_apis = load_external_apis(build_config) - if external_apis: - for _, api_spec in external_apis.items(): - normal_deps.extend(api_spec.pip_packages) - - if build_config.image_type == LlamaStackImageType.CONTAINER.value: - script = str(importlib.resources.files("llama_stack") / "core/build_container.sh") - args = [ - script, - "--distro-or-config", - distro_or_config, - "--image-name", - image_name, - "--container-base", - container_base, - "--normal-deps", - " ".join(normal_deps), - ] - # When building from a config file (not a template), include the run config path in the - # build arguments - if run_config is not None: - args.extend(["--run-config", run_config]) - else: - script = str(importlib.resources.files("llama_stack") / "core/build_venv.sh") - args = [ - script, - "--env-name", - str(image_name), - "--normal-deps", - " 
".join(normal_deps), - ] - - # Always pass both arguments, even if empty, to maintain consistent positional arguments - if special_deps: - args.extend(["--optional-deps", "#".join(special_deps)]) - if external_provider_deps: - args.extend( - ["--external-provider-deps", "#".join(external_provider_deps)] - ) # the script will install external provider module, get its deps, and install those too. - - return_code = run_command(args) - - if return_code != 0: - log.error( - f"Failed to build target {image_name} with return code {return_code}", - ) - - return return_code From 539b9c08f38269a80aa5f79cc348b5a2a6032ba3 Mon Sep 17 00:00:00 2001 From: Akshay Ghodake Date: Wed, 12 Nov 2025 14:54:19 +0530 Subject: [PATCH 10/15] chore(deps): update pypdf to fix DoS vulnerabilities (#4121) Update pypdf dependency to address vulnerabilities causing potential denial of service through infinite loops or excessive memory usage when handling malicious PDFs. The update remains fully backward compatible, with no changes to the PdfReader API. # What does this PR do? 
Fixes #4120 ## Test Plan Co-authored-by: Francisco Arceo --- pyproject.toml | 4 ++-- uv.lock | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 653c6d613..e6808af8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,7 +112,7 @@ unit = [ "aiosqlite", "aiohttp", "psycopg2-binary>=2.9.0", - "pypdf", + "pypdf>=6.1.3", "mcp", "chardet", "sqlalchemy", @@ -135,7 +135,7 @@ test = [ "torchvision>=0.21.0", "chardet", "psycopg2-binary>=2.9.0", - "pypdf", + "pypdf>=6.1.3", "mcp", "datasets>=4.0.0", "autoevals", diff --git a/uv.lock b/uv.lock index ba9a862a3..f1808f005 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.12" resolution-markers = [ "(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')", @@ -2166,7 +2166,7 @@ test = [ { name = "milvus-lite", specifier = ">=2.5.0" }, { name = "psycopg2-binary", specifier = ">=2.9.0" }, { name = "pymilvus", specifier = ">=2.6.1" }, - { name = "pypdf" }, + { name = "pypdf", specifier = ">=6.1.3" }, { name = "qdrant-client" }, { name = "requests" }, { name = "sqlalchemy" }, @@ -2219,7 +2219,7 @@ unit = [ { name = "moto", extras = ["s3"], specifier = ">=5.1.10" }, { name = "ollama" }, { name = "psycopg2-binary", specifier = ">=2.9.0" }, - { name = "pypdf" }, + { name = "pypdf", specifier = ">=6.1.3" }, { name = "sqlalchemy" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" }, { name = "sqlite-vec" }, @@ -3973,11 +3973,11 @@ wheels = [ [[package]] name = "pypdf" -version = "5.9.0" +version = "6.2.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/89/3a/584b97a228950ed85aec97c811c68473d9b8d149e6a8c155668287cf1a28/pypdf-5.9.0.tar.gz", hash = 
"sha256:30f67a614d558e495e1fbb157ba58c1de91ffc1718f5e0dfeb82a029233890a1", size = 5035118, upload-time = "2025-07-27T14:04:52.364Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4e/2b/8795ec0378384000b0a37a2b5e6d67fa3d84802945aa2c612a78a784d7d4/pypdf-6.2.0.tar.gz", hash = "sha256:46b4d8495d68ae9c818e7964853cd9984e6a04c19fe7112760195395992dce48", size = 5272001, upload-time = "2025-11-09T11:10:41.911Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/48/d9/6cff57c80a6963e7dd183bf09e9f21604a77716644b1e580e97b259f7612/pypdf-5.9.0-py3-none-any.whl", hash = "sha256:be10a4c54202f46d9daceaa8788be07aa8cd5ea8c25c529c50dd509206382c35", size = 313193, upload-time = "2025-07-27T14:04:50.53Z" }, + { url = "https://files.pythonhosted.org/packages/de/ba/743ddcaf1a8fb439342399645921e2cf2c600464cba5531a11f1cc0822b6/pypdf-6.2.0-py3-none-any.whl", hash = "sha256:4c0f3e62677217a777ab79abe22bf1285442d70efabf552f61c7a03b6f5c569f", size = 326592, upload-time = "2025-11-09T11:10:39.941Z" }, ] [[package]] From 63137f9af1fde09eee62a0b28798297a9166c42e Mon Sep 17 00:00:00 2001 From: Sam El-Borai Date: Wed, 12 Nov 2025 17:39:21 +0100 Subject: [PATCH 11/15] chore(stainless): add config for file header (#4126) # What does this PR do? This PR adds Stainless config to specify the Meta copyright file header for generated files. Doing it via config instead of custom code will reduce the probability of git conflict. ## Test Plan - review preview builds --- client-sdks/stainless/config.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/client-sdks/stainless/config.yml b/client-sdks/stainless/config.yml index ab9342c49..c61b53654 100644 --- a/client-sdks/stainless/config.yml +++ b/client-sdks/stainless/config.yml @@ -463,6 +463,12 @@ resources: settings: license: MIT unwrap_response_fields: [data] + file_header: | + Copyright (c) Meta Platforms, Inc. and affiliates. + All rights reserved. 
+ + This source code is licensed under the terms described in the LICENSE file in + the root directory of this source tree. openapi: transformations: From 37853ca5581a832ef7db9a130b2064ae705bcce3 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Wed, 12 Nov 2025 12:17:13 -0500 Subject: [PATCH 12/15] fix(tests): add OpenAI client connection cleanup to prevent CI hangs (#4119) # What does this PR do? Add explicit connection cleanup and shorter timeouts to OpenAI client fixtures. Fixes CI deadlock after 25+ tests due to connection pool exhaustion. Also adds 60s timeout to test_conversation_context_loading as safety net. ## Test Plan tests pass Signed-off-by: Charlie Doern --- tests/integration/fixtures/common.py | 8 +++++++- tests/integration/responses/fixtures/fixtures.py | 10 +++++++++- .../responses/test_conversation_responses.py | 8 +++++++- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py index d5e4c15f7..407564c15 100644 --- a/tests/integration/fixtures/common.py +++ b/tests/integration/fixtures/common.py @@ -323,7 +323,13 @@ def require_server(llama_stack_client): @pytest.fixture(scope="session") def openai_client(llama_stack_client, require_server): base_url = f"{llama_stack_client.base_url}/v1" - return OpenAI(base_url=base_url, api_key="fake") + client = OpenAI(base_url=base_url, api_key="fake", max_retries=0, timeout=30.0) + yield client + # Cleanup: close HTTP connections + try: + client.close() + except Exception: + pass @pytest.fixture(params=["openai_client", "client_with_models"]) diff --git a/tests/integration/responses/fixtures/fixtures.py b/tests/integration/responses/fixtures/fixtures.py index dbf67e138..b06117b98 100644 --- a/tests/integration/responses/fixtures/fixtures.py +++ b/tests/integration/responses/fixtures/fixtures.py @@ -115,7 +115,15 @@ def openai_client(base_url, api_key, provider): client = LlamaStackAsLibraryClient(config, 
skip_logger_removal=True) return client - return OpenAI( + client = OpenAI( base_url=base_url, api_key=api_key, + max_retries=0, + timeout=30.0, ) + yield client + # Cleanup: close HTTP connections + try: + client.close() + except Exception: + pass diff --git a/tests/integration/responses/test_conversation_responses.py b/tests/integration/responses/test_conversation_responses.py index ef7ea7c4e..babb77793 100644 --- a/tests/integration/responses/test_conversation_responses.py +++ b/tests/integration/responses/test_conversation_responses.py @@ -65,8 +65,14 @@ class TestConversationResponses: conversation_items = openai_client.conversations.items.list(conversation.id) assert len(conversation_items.data) >= 4 # 2 user + 2 assistant messages + @pytest.mark.timeout(60, method="thread") def test_conversation_context_loading(self, openai_client, text_model_id): - """Test that conversation context is properly loaded for responses.""" + """Test that conversation context is properly loaded for responses. + + Note: 60s timeout added due to CI-specific deadlock in pytest/OpenAI client/httpx + after running 25+ tests. Hangs before first HTTP request is made. Works fine locally. + Investigation needed: connection pool exhaustion or event loop state issue. + """ conversation = openai_client.conversations.create( items=[ {"type": "message", "role": "user", "content": "My name is Alice. I like to eat apples."}, From eb3f9ac2781d0079eb65ea14b77296fcd3d317d4 Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Wed, 12 Nov 2025 12:59:48 -0500 Subject: [PATCH 13/15] feat: allow returning embeddings and metadata from `/vector_stores/` methods; disallow changing Provider ID (#4046) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? - Updates `/vector_stores/{vector_store_id}/files/{file_id}/content` to allow returning `embeddings` and `metadata` using the `extra_query` - Updates the UI accordingly to display them. 
- Update UI to support CRUD operations in the Vector Stores section and adds a new modal exposing the functionality. - Updates Vector Store update to fail if a user tries to update Provider ID (which doesn't make sense to allow) ```python In [1]: client.vector_stores.files.content( vector_store_id=vector_store.id, file_id=file.id, extra_query={"include_embeddings": True, "include_metadata": True} ) Out [1]: FileContentResponse(attributes={}, content=[Content(text='This is a test document to check if embeddings are generated properly.\n', type='text', embedding=[0.33760684728622437, ...,], chunk_metadata={'chunk_id': '62a63ae0-c202-f060-1b86-0a688995b8d3', 'document_id': 'file-27291dbc679642ac94ffac6d2810c339', 'source': None, 'created_timestamp': 1762053437, 'updated_timestamp': 1762053437, 'chunk_window': '0-13', 'chunk_tokenizer': 'DEFAULT_TIKTOKEN_TOKENIZER', 'chunk_embedding_model': 'sentence-transformers/nomic -ai/nomic-embed-text-v1.5', 'chunk_embedding_dimension': 768, 'content_token_count': 13, 'metadata_token_count': 9}, metadata={'filename': 'test-embedding.txt', 'chunk_id': '62a63ae0-c202-f060-1b86-0a688995b8d3', 'document_id': 'file-27291dbc679642ac94ffac6d2810c339', 'token_count': 13, 'metadata_token_count': 9})], file_id='file-27291dbc679642ac94ffac6d2810c339', filename='test-embedding.txt') ``` Screenshots of UI are displayed below: ### List Vector Store with Added "Create New Vector Store" Screenshot 2025-11-06 at 10 47
25 PM ### Create New Vector Store Screenshot 2025-11-06 at 10 47
49 PM ### Edit Vector Store Screenshot 2025-11-06 at 10 48
32 PM ### Vector Store Files Contents page (with Embeddings) Screenshot 2025-11-06 at 11 54
32 PM ### Vector Store Files Contents Details page (with Embeddings) Screenshot 2025-11-06 at 11 55
00 PM ## Test Plan Tests added for Middleware extension and Provider failures. --------- Signed-off-by: Francisco Javier Arceo --- client-sdks/stainless/openapi.yml | 40 +- docs/static/llama-stack-spec.yaml | 40 +- docs/static/stainless-llama-stack-spec.yaml | 40 +- src/llama_stack/apis/vector_io/vector_io.py | 46 ++- src/llama_stack/core/library_client.py | 6 + src/llama_stack/core/routers/vector_io.py | 20 +- .../core/routing_tables/vector_stores.py | 5 + .../utils/memory/openai_vector_store_mixin.py | 64 +-- .../app/logs/vector-stores/page.tsx | 386 +++++++++++++++--- .../components/prompts/prompt-editor.test.tsx | 2 +- .../vector-store-detail.test.tsx | 14 + .../vector-stores/vector-store-detail.tsx | 183 ++++++++- .../vector-stores/vector-store-editor.tsx | 235 +++++++++++ src/llama_stack_ui/lib/contents-api.ts | 40 +- .../vector_io/test_openai_vector_stores.py | 95 +++++ tests/unit/core/routers/test_vector_io.py | 62 +++ tests/unit/server/test_sse.py | 8 +- 17 files changed, 1161 insertions(+), 125 deletions(-) create mode 100644 src/llama_stack_ui/components/vector-stores/vector-store-editor.tsx diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml index 9f3ef15b5..1be4af6c9 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -2691,7 +2691,8 @@ paths: responses: '200': description: >- - A VectorStoreFileContentResponse representing the file contents. + File contents, optionally with embeddings and metadata based on query + parameters. content: application/json: schema: @@ -2726,6 +2727,20 @@ paths: required: true schema: type: string + - name: include_embeddings + in: query + description: >- + Whether to include embedding vectors in the response. + required: false + schema: + $ref: '#/components/schemas/bool' + - name: include_metadata + in: query + description: >- + Whether to include chunk metadata in the response. 
+ required: false + schema: + $ref: '#/components/schemas/bool' deprecated: false /v1/vector_stores/{vector_store_id}/search: post: @@ -10091,6 +10106,8 @@ components: title: VectorStoreFileDeleteResponse description: >- Response from deleting a vector store file. + bool: + type: boolean VectorStoreContent: type: object properties: @@ -10102,6 +10119,26 @@ components: text: type: string description: The actual text content + embedding: + type: array + items: + type: number + description: >- + Optional embedding vector for this content chunk + chunk_metadata: + $ref: '#/components/schemas/ChunkMetadata' + description: Optional chunk metadata + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: Optional user-defined metadata additionalProperties: false required: - type @@ -10125,6 +10162,7 @@ components: description: Parsed content of the file has_more: type: boolean + default: false description: >- Indicates if there are more content pages to fetch next_page: diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index ce8708b68..66eda78c7 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -2688,7 +2688,8 @@ paths: responses: '200': description: >- - A VectorStoreFileContentResponse representing the file contents. + File contents, optionally with embeddings and metadata based on query + parameters. content: application/json: schema: @@ -2723,6 +2724,20 @@ paths: required: true schema: type: string + - name: include_embeddings + in: query + description: >- + Whether to include embedding vectors in the response. + required: false + schema: + $ref: '#/components/schemas/bool' + - name: include_metadata + in: query + description: >- + Whether to include chunk metadata in the response. 
+ required: false + schema: + $ref: '#/components/schemas/bool' deprecated: false /v1/vector_stores/{vector_store_id}/search: post: @@ -9375,6 +9390,8 @@ components: title: VectorStoreFileDeleteResponse description: >- Response from deleting a vector store file. + bool: + type: boolean VectorStoreContent: type: object properties: @@ -9386,6 +9403,26 @@ components: text: type: string description: The actual text content + embedding: + type: array + items: + type: number + description: >- + Optional embedding vector for this content chunk + chunk_metadata: + $ref: '#/components/schemas/ChunkMetadata' + description: Optional chunk metadata + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: Optional user-defined metadata additionalProperties: false required: - type @@ -9409,6 +9446,7 @@ components: description: Parsed content of the file has_more: type: boolean + default: false description: >- Indicates if there are more content pages to fetch next_page: diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index 9f3ef15b5..1be4af6c9 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -2691,7 +2691,8 @@ paths: responses: '200': description: >- - A VectorStoreFileContentResponse representing the file contents. + File contents, optionally with embeddings and metadata based on query + parameters. content: application/json: schema: @@ -2726,6 +2727,20 @@ paths: required: true schema: type: string + - name: include_embeddings + in: query + description: >- + Whether to include embedding vectors in the response. + required: false + schema: + $ref: '#/components/schemas/bool' + - name: include_metadata + in: query + description: >- + Whether to include chunk metadata in the response. 
+ required: false + schema: + $ref: '#/components/schemas/bool' deprecated: false /v1/vector_stores/{vector_store_id}/search: post: @@ -10091,6 +10106,8 @@ components: title: VectorStoreFileDeleteResponse description: >- Response from deleting a vector store file. + bool: + type: boolean VectorStoreContent: type: object properties: @@ -10102,6 +10119,26 @@ components: text: type: string description: The actual text content + embedding: + type: array + items: + type: number + description: >- + Optional embedding vector for this content chunk + chunk_metadata: + $ref: '#/components/schemas/ChunkMetadata' + description: Optional chunk metadata + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: Optional user-defined metadata additionalProperties: false required: - type @@ -10125,6 +10162,7 @@ components: description: Parsed content of the file has_more: type: boolean + default: false description: >- Indicates if there are more content pages to fetch next_page: diff --git a/src/llama_stack/apis/vector_io/vector_io.py b/src/llama_stack/apis/vector_io/vector_io.py index 846c6f191..699241128 100644 --- a/src/llama_stack/apis/vector_io/vector_io.py +++ b/src/llama_stack/apis/vector_io/vector_io.py @@ -10,7 +10,7 @@ # the root directory of this source tree. 
from typing import Annotated, Any, Literal, Protocol, runtime_checkable -from fastapi import Body +from fastapi import Body, Query from pydantic import BaseModel, Field from llama_stack.apis.common.tracing import telemetry_traceable @@ -224,10 +224,16 @@ class VectorStoreContent(BaseModel): :param type: Content type, currently only "text" is supported :param text: The actual text content + :param embedding: Optional embedding vector for this content chunk + :param chunk_metadata: Optional chunk metadata + :param metadata: Optional user-defined metadata """ type: Literal["text"] text: str + embedding: list[float] | None = None + chunk_metadata: ChunkMetadata | None = None + metadata: dict[str, Any] | None = None @json_schema_type @@ -280,6 +286,22 @@ class VectorStoreDeleteResponse(BaseModel): deleted: bool = True +@json_schema_type +class VectorStoreFileContentResponse(BaseModel): + """Represents the parsed content of a vector store file. + + :param object: The object type, which is always `vector_store.file_content.page` + :param data: Parsed content of the file + :param has_more: Indicates if there are more content pages to fetch + :param next_page: The token for the next page, if any + """ + + object: Literal["vector_store.file_content.page"] = "vector_store.file_content.page" + data: list[VectorStoreContent] + has_more: bool = False + next_page: str | None = None + + @json_schema_type class VectorStoreChunkingStrategyAuto(BaseModel): """Automatic chunking strategy for vector store files. @@ -395,22 +417,6 @@ class VectorStoreListFilesResponse(BaseModel): has_more: bool = False -@json_schema_type -class VectorStoreFileContentResponse(BaseModel): - """Represents the parsed content of a vector store file. 
- - :param object: The object type, which is always `vector_store.file_content.page` - :param data: Parsed content of the file - :param has_more: Indicates if there are more content pages to fetch - :param next_page: The token for the next page, if any - """ - - object: Literal["vector_store.file_content.page"] = "vector_store.file_content.page" - data: list[VectorStoreContent] - has_more: bool - next_page: str | None = None - - @json_schema_type class VectorStoreFileDeleteResponse(BaseModel): """Response from deleting a vector store file. @@ -732,12 +738,16 @@ class VectorIO(Protocol): self, vector_store_id: str, file_id: str, + include_embeddings: Annotated[bool | None, Query(default=False)] = False, + include_metadata: Annotated[bool | None, Query(default=False)] = False, ) -> VectorStoreFileContentResponse: """Retrieves the contents of a vector store file. :param vector_store_id: The ID of the vector store containing the file to retrieve. :param file_id: The ID of the file to retrieve. - :returns: A VectorStoreFileContentResponse representing the file contents. + :param include_embeddings: Whether to include embedding vectors in the response. + :param include_metadata: Whether to include chunk metadata in the response. + :returns: File contents, optionally with embeddings and metadata based on query parameters. """ ... 
diff --git a/src/llama_stack/core/library_client.py b/src/llama_stack/core/library_client.py index b8f9f715f..db990368b 100644 --- a/src/llama_stack/core/library_client.py +++ b/src/llama_stack/core/library_client.py @@ -389,6 +389,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls) body |= path_params + # Pass through params that aren't already handled as path params + if options.params: + extra_query_params = {k: v for k, v in options.params.items() if k not in path_params} + if extra_query_params: + body["extra_query"] = extra_query_params + body, field_names = self._handle_file_uploads(options, body) body = self._convert_body(matched_func, body, exclude_params=set(field_names)) diff --git a/src/llama_stack/core/routers/vector_io.py b/src/llama_stack/core/routers/vector_io.py index 9dac461db..ed5fb8253 100644 --- a/src/llama_stack/core/routers/vector_io.py +++ b/src/llama_stack/core/routers/vector_io.py @@ -247,6 +247,13 @@ class VectorIORouter(VectorIO): metadata: dict[str, Any] | None = None, ) -> VectorStoreObject: logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}") + + # Check if provider_id is being changed (not supported) + if metadata and "provider_id" in metadata: + current_store = await self.routing_table.get_object_by_identifier("vector_store", vector_store_id) + if current_store and current_store.provider_id != metadata["provider_id"]: + raise ValueError("provider_id cannot be changed after vector store creation") + provider = await self.routing_table.get_provider_impl(vector_store_id) return await provider.openai_update_vector_store( vector_store_id=vector_store_id, @@ -338,12 +345,19 @@ class VectorIORouter(VectorIO): self, vector_store_id: str, file_id: str, + include_embeddings: bool | None = False, + include_metadata: bool | None = False, ) -> VectorStoreFileContentResponse: - 
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}") - provider = await self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_retrieve_vector_store_file_contents( + logger.debug( + f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}, " + f"include_embeddings={include_embeddings}, include_metadata={include_metadata}" + ) + + return await self.routing_table.openai_retrieve_vector_store_file_contents( vector_store_id=vector_store_id, file_id=file_id, + include_embeddings=include_embeddings, + include_metadata=include_metadata, ) async def openai_update_vector_store_file( diff --git a/src/llama_stack/core/routing_tables/vector_stores.py b/src/llama_stack/core/routing_tables/vector_stores.py index f95a4dbe3..e77739abe 100644 --- a/src/llama_stack/core/routing_tables/vector_stores.py +++ b/src/llama_stack/core/routing_tables/vector_stores.py @@ -195,12 +195,17 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): self, vector_store_id: str, file_id: str, + include_embeddings: bool | None = False, + include_metadata: bool | None = False, ) -> VectorStoreFileContentResponse: await self.assert_action_allowed("read", "vector_store", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) return await provider.openai_retrieve_vector_store_file_contents( vector_store_id=vector_store_id, file_id=file_id, + include_embeddings=include_embeddings, + include_metadata=include_metadata, ) async def openai_update_vector_store_file( diff --git a/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 86e6ea013..853245598 100644 --- a/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -704,34 +704,35 @@ class OpenAIVectorStoreMixin(ABC): # Unknown filter type, 
default to no match raise ValueError(f"Unsupported filter type: {filter_type}") - def _chunk_to_vector_store_content(self, chunk: Chunk) -> list[VectorStoreContent]: - # content is InterleavedContent + def _chunk_to_vector_store_content( + self, chunk: Chunk, include_embeddings: bool = False, include_metadata: bool = False + ) -> list[VectorStoreContent]: + def extract_fields() -> dict: + """Extract embedding and metadata fields from chunk based on include flags.""" + return { + "embedding": chunk.embedding if include_embeddings else None, + "chunk_metadata": chunk.chunk_metadata if include_metadata else None, + "metadata": chunk.metadata if include_metadata else None, + } + + fields = extract_fields() + if isinstance(chunk.content, str): - content = [ - VectorStoreContent( - type="text", - text=chunk.content, - ) - ] + content_item = VectorStoreContent(type="text", text=chunk.content, **fields) + content = [content_item] elif isinstance(chunk.content, list): # TODO: Add support for other types of content - content = [ - VectorStoreContent( - type="text", - text=item.text, - ) - for item in chunk.content - if item.type == "text" - ] + content = [] + for item in chunk.content: + if item.type == "text": + content_item = VectorStoreContent(type="text", text=item.text, **fields) + content.append(content_item) else: if chunk.content.type != "text": raise ValueError(f"Unsupported content type: {chunk.content.type}") - content = [ - VectorStoreContent( - type="text", - text=chunk.content.text, - ) - ] + + content_item = VectorStoreContent(type="text", text=chunk.content.text, **fields) + content = [content_item] return content async def openai_attach_file_to_vector_store( @@ -820,13 +821,12 @@ class OpenAIVectorStoreMixin(ABC): message=str(e), ) - # Create OpenAI vector store file metadata + # Save vector store file to persistent storage AFTER insert_chunks + # so that chunks include the embeddings that were generated file_info = 
vector_store_file_object.model_dump(exclude={"last_error"}) file_info["filename"] = file_response.filename if file_response else "" - # Save vector store file to persistent storage (provider-specific) dict_chunks = [c.model_dump() for c in chunks] - # This should be updated to include chunk_id await self._save_openai_vector_store_file(vector_store_id, file_id, file_info, dict_chunks) # Update file_ids and file_counts in vector store metadata @@ -921,21 +921,27 @@ class OpenAIVectorStoreMixin(ABC): self, vector_store_id: str, file_id: str, + include_embeddings: bool | None = False, + include_metadata: bool | None = False, ) -> VectorStoreFileContentResponse: """Retrieves the contents of a vector store file.""" if vector_store_id not in self.openai_vector_stores: raise VectorStoreNotFoundError(vector_store_id) + # Parameters are already provided directly + # include_embeddings and include_metadata are now function parameters + dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id) chunks = [Chunk.model_validate(c) for c in dict_chunks] content = [] for chunk in chunks: - content.extend(self._chunk_to_vector_store_content(chunk)) + content.extend( + self._chunk_to_vector_store_content( + chunk, include_embeddings=include_embeddings or False, include_metadata=include_metadata or False + ) + ) return VectorStoreFileContentResponse( - object="vector_store.file_content.page", data=content, - has_more=False, - next_page=None, ) async def openai_update_vector_store_file( diff --git a/src/llama_stack_ui/app/logs/vector-stores/page.tsx b/src/llama_stack_ui/app/logs/vector-stores/page.tsx index 72196d496..84680e01a 100644 --- a/src/llama_stack_ui/app/logs/vector-stores/page.tsx +++ b/src/llama_stack_ui/app/logs/vector-stores/page.tsx @@ -8,6 +8,9 @@ import type { import { useRouter } from "next/navigation"; import { usePagination } from "@/hooks/use-pagination"; import { Button } from "@/components/ui/button"; +import { Plus, Trash2, Search, 
Edit, X } from "lucide-react"; +import { useState } from "react"; +import { Input } from "@/components/ui/input"; import { Table, TableBody, @@ -17,9 +20,21 @@ import { TableRow, } from "@/components/ui/table"; import { Skeleton } from "@/components/ui/skeleton"; +import { useAuthClient } from "@/hooks/use-auth-client"; +import { + VectorStoreEditor, + VectorStoreFormData, +} from "@/components/vector-stores/vector-store-editor"; export default function VectorStoresPage() { const router = useRouter(); + const client = useAuthClient(); + const [deletingStores, setDeletingStores] = useState>(new Set()); + const [searchTerm, setSearchTerm] = useState(""); + const [showVectorStoreModal, setShowVectorStoreModal] = useState(false); + const [editingStore, setEditingStore] = useState(null); + const [modalError, setModalError] = useState(null); + const [showSuccessState, setShowSuccessState] = useState(false); const { data: stores, status, @@ -47,6 +62,142 @@ export default function VectorStoresPage() { } }, [status, hasMore, loadMore]); + // Handle ESC key to close modal + React.useEffect(() => { + const handleEscape = (event: KeyboardEvent) => { + if (event.key === "Escape" && showVectorStoreModal) { + handleCancel(); + } + }; + + document.addEventListener("keydown", handleEscape); + return () => document.removeEventListener("keydown", handleEscape); + }, [showVectorStoreModal]); + + const handleDeleteVectorStore = async (storeId: string) => { + if ( + !confirm( + "Are you sure you want to delete this vector store? This action cannot be undone." + ) + ) { + return; + } + + setDeletingStores(prev => new Set([...prev, storeId])); + + try { + await client.vectorStores.delete(storeId); + // Reload the data to reflect the deletion + window.location.reload(); + } catch (err: unknown) { + console.error("Failed to delete vector store:", err); + const errorMessage = err instanceof Error ? 
err.message : "Unknown error"; + alert(`Failed to delete vector store: ${errorMessage}`); + } finally { + setDeletingStores(prev => { + const newSet = new Set(prev); + newSet.delete(storeId); + return newSet; + }); + } + }; + + const handleSaveVectorStore = async (formData: VectorStoreFormData) => { + try { + setModalError(null); + + if (editingStore) { + // Update existing vector store + const updateParams: { + name?: string; + extra_body?: Record; + } = {}; + + // Only include fields that have changed or are provided + if (formData.name && formData.name !== editingStore.name) { + updateParams.name = formData.name; + } + + // Add all parameters to extra_body (except provider_id which can't be changed) + const extraBody: Record = {}; + if (formData.embedding_model) { + extraBody.embedding_model = formData.embedding_model; + } + if (formData.embedding_dimension) { + extraBody.embedding_dimension = formData.embedding_dimension; + } + + if (Object.keys(extraBody).length > 0) { + updateParams.extra_body = extraBody; + } + + await client.vectorStores.update(editingStore.id, updateParams); + + // Show success state with close button + setShowSuccessState(true); + setModalError( + "✅ Vector store updated successfully! You can close this modal and refresh the page to see changes." 
+ ); + return; + } + + const createParams: { + name?: string; + provider_id?: string; + extra_body?: Record; + } = { + name: formData.name || undefined, + }; + + // Extract provider_id to top-level (like Python client does) + if (formData.provider_id) { + createParams.provider_id = formData.provider_id; + } + + // Add remaining parameters to extra_body + const extraBody: Record = {}; + if (formData.provider_id) { + extraBody.provider_id = formData.provider_id; + } + if (formData.embedding_model) { + extraBody.embedding_model = formData.embedding_model; + } + if (formData.embedding_dimension) { + extraBody.embedding_dimension = formData.embedding_dimension; + } + + if (Object.keys(extraBody).length > 0) { + createParams.extra_body = extraBody; + } + + await client.vectorStores.create(createParams); + + // Show success state with close button + setShowSuccessState(true); + setModalError( + "✅ Vector store created successfully! You can close this modal and refresh the page to see changes." + ); + } catch (err: unknown) { + console.error("Failed to create vector store:", err); + const errorMessage = + err instanceof Error ? err.message : "Failed to create vector store"; + setModalError(errorMessage); + } + }; + + const handleEditVectorStore = (store: VectorStore) => { + setEditingStore(store); + setShowVectorStoreModal(true); + setModalError(null); + }; + + const handleCancel = () => { + setShowVectorStoreModal(false); + setEditingStore(null); + setModalError(null); + setShowSuccessState(false); + }; + const renderContent = () => { if (status === "loading") { return ( @@ -66,73 +217,190 @@ export default function VectorStoresPage() { return

No vector stores found.

; } - return ( -
- - - - ID - Name - Created - Completed - Cancelled - Failed - In Progress - Total - Usage Bytes - Provider ID - Provider Vector DB ID - - - - {stores.map(store => { - const fileCounts = store.file_counts; - const metadata = store.metadata || {}; - const providerId = metadata.provider_id ?? ""; - const providerDbId = metadata.provider_vector_db_id ?? ""; + // Filter stores based on search term + const filteredStores = stores.filter(store => { + if (!searchTerm) return true; - return ( - router.push(`/logs/vector-stores/${store.id}`)} - className="cursor-pointer hover:bg-muted/50" - > - - - - {store.name} - - {new Date(store.created_at * 1000).toLocaleString()} - - {fileCounts.completed} - {fileCounts.cancelled} - {fileCounts.failed} - {fileCounts.in_progress} - {fileCounts.total} - {store.usage_bytes} - {providerId} - {providerDbId} - - ); - })} - -
+ const searchLower = searchTerm.toLowerCase(); + return ( + store.id.toLowerCase().includes(searchLower) || + (store.name && store.name.toLowerCase().includes(searchLower)) || + (store.metadata?.provider_id && + String(store.metadata.provider_id) + .toLowerCase() + .includes(searchLower)) || + (store.metadata?.provider_vector_db_id && + String(store.metadata.provider_vector_db_id) + .toLowerCase() + .includes(searchLower)) + ); + }); + + return ( +
+ {/* Search Bar */} +
+ + setSearchTerm(e.target.value)} + className="pl-10" + /> +
+ +
+ + + + ID + Name + Created + Completed + Cancelled + Failed + In Progress + Total + Usage Bytes + Provider ID + Provider Vector DB ID + Actions + + + + {filteredStores.map(store => { + const fileCounts = store.file_counts; + const metadata = store.metadata || {}; + const providerId = metadata.provider_id ?? ""; + const providerDbId = metadata.provider_vector_db_id ?? ""; + + return ( + + router.push(`/logs/vector-stores/${store.id}`) + } + className="cursor-pointer hover:bg-muted/50" + > + + + + {store.name} + + {new Date(store.created_at * 1000).toLocaleString()} + + {fileCounts.completed} + {fileCounts.cancelled} + {fileCounts.failed} + {fileCounts.in_progress} + {fileCounts.total} + {store.usage_bytes} + {providerId} + {providerDbId} + +
+ + +
+
+
+ ); + })} +
+
+
); }; return (
-

Vector Stores

+
+

Vector Stores

+ +
{renderContent()} + + {/* Create Vector Store Modal */} + {showVectorStoreModal && ( +
+
+
+

+ {editingStore ? "Edit Vector Store" : "Create New Vector Store"} +

+ +
+
+ +
+
+
+ )}
); } diff --git a/src/llama_stack_ui/components/prompts/prompt-editor.test.tsx b/src/llama_stack_ui/components/prompts/prompt-editor.test.tsx index 458a5f942..70e0e4e66 100644 --- a/src/llama_stack_ui/components/prompts/prompt-editor.test.tsx +++ b/src/llama_stack_ui/components/prompts/prompt-editor.test.tsx @@ -2,7 +2,7 @@ import React from "react"; import { render, screen, fireEvent } from "@testing-library/react"; import "@testing-library/jest-dom"; import { PromptEditor } from "./prompt-editor"; -import type { Prompt, PromptFormData } from "./types"; +import type { Prompt } from "./types"; describe("PromptEditor", () => { const mockOnSave = jest.fn(); diff --git a/src/llama_stack_ui/components/vector-stores/vector-store-detail.test.tsx b/src/llama_stack_ui/components/vector-stores/vector-store-detail.test.tsx index 08f90ac0d..78bec8147 100644 --- a/src/llama_stack_ui/components/vector-stores/vector-store-detail.test.tsx +++ b/src/llama_stack_ui/components/vector-stores/vector-store-detail.test.tsx @@ -12,6 +12,20 @@ jest.mock("next/navigation", () => ({ }), })); +// Mock NextAuth +jest.mock("next-auth/react", () => ({ + useSession: () => ({ + data: { + accessToken: "mock-access-token", + user: { + id: "mock-user-id", + email: "test@example.com", + }, + }, + status: "authenticated", + }), +})); + describe("VectorStoreDetailView", () => { const defaultProps = { store: null, diff --git a/src/llama_stack_ui/components/vector-stores/vector-store-detail.tsx b/src/llama_stack_ui/components/vector-stores/vector-store-detail.tsx index d3d0fa249..f5b6281e7 100644 --- a/src/llama_stack_ui/components/vector-stores/vector-store-detail.tsx +++ b/src/llama_stack_ui/components/vector-stores/vector-store-detail.tsx @@ -1,16 +1,18 @@ "use client"; import { useRouter } from "next/navigation"; +import { useState, useEffect } from "react"; import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores"; import type { VectorStoreFile } from 
"llama-stack-client/resources/vector-stores/files"; import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; import { Skeleton } from "@/components/ui/skeleton"; import { Button } from "@/components/ui/button"; +import { useAuthClient } from "@/hooks/use-auth-client"; +import { Edit2, Trash2, X } from "lucide-react"; import { DetailLoadingView, DetailErrorView, DetailNotFoundView, - DetailLayout, PropertiesCard, PropertyItem, } from "@/components/layout/detail-layout"; @@ -23,6 +25,7 @@ import { TableHeader, TableRow, } from "@/components/ui/table"; +import { VectorStoreEditor, VectorStoreFormData } from "./vector-store-editor"; interface VectorStoreDetailViewProps { store: VectorStore | null; @@ -43,21 +46,122 @@ export function VectorStoreDetailView({ errorFiles, id, }: VectorStoreDetailViewProps) { - const title = "Vector Store Details"; const router = useRouter(); + const client = useAuthClient(); + const [isDeleting, setIsDeleting] = useState(false); + const [showEditModal, setShowEditModal] = useState(false); + const [modalError, setModalError] = useState(null); + const [showSuccessState, setShowSuccessState] = useState(false); + + // Handle ESC key to close modal + useEffect(() => { + const handleEscape = (event: KeyboardEvent) => { + if (event.key === "Escape" && showEditModal) { + handleCancel(); + } + }; + + document.addEventListener("keydown", handleEscape); + return () => document.removeEventListener("keydown", handleEscape); + }, [showEditModal]); const handleFileClick = (fileId: string) => { router.push(`/logs/vector-stores/${id}/files/${fileId}`); }; + const handleEditVectorStore = () => { + setShowEditModal(true); + setModalError(null); + setShowSuccessState(false); + }; + + const handleCancel = () => { + setShowEditModal(false); + setModalError(null); + setShowSuccessState(false); + }; + + const handleSaveVectorStore = async (formData: VectorStoreFormData) => { + try { + setModalError(null); + + // Update existing vector 
store (same logic as list page) + const updateParams: { + name?: string; + extra_body?: Record; + } = {}; + + // Only include fields that have changed or are provided + if (formData.name && formData.name !== store?.name) { + updateParams.name = formData.name; + } + + // Add all parameters to extra_body (except provider_id which can't be changed) + const extraBody: Record = {}; + if (formData.embedding_model) { + extraBody.embedding_model = formData.embedding_model; + } + if (formData.embedding_dimension) { + extraBody.embedding_dimension = formData.embedding_dimension; + } + + if (Object.keys(extraBody).length > 0) { + updateParams.extra_body = extraBody; + } + + await client.vectorStores.update(id, updateParams); + + // Show success state + setShowSuccessState(true); + setModalError( + "✅ Vector store updated successfully! You can close this modal and refresh the page to see changes." + ); + } catch (err: unknown) { + console.error("Failed to update vector store:", err); + const errorMessage = + err instanceof Error ? err.message : "Failed to update vector store"; + setModalError(errorMessage); + } + }; + + const handleDeleteVectorStore = async () => { + if ( + !confirm( + "Are you sure you want to delete this vector store? This action cannot be undone." + ) + ) { + return; + } + + setIsDeleting(true); + + try { + await client.vectorStores.delete(id); + // Redirect to the vector stores list after successful deletion + router.push("/logs/vector-stores"); + } catch (err: unknown) { + console.error("Failed to delete vector store:", err); + const errorMessage = err instanceof Error ? err.message : "Unknown error"; + alert(`Failed to delete vector store: ${errorMessage}`); + } finally { + setIsDeleting(false); + } + }; + if (errorStore) { - return ; + return ( + + ); } if (isLoadingStore) { - return ; + return ; } if (!store) { - return ; + return ; } const mainContent = ( @@ -138,6 +242,73 @@ export function VectorStoreDetailView({ ); return ( - + <> +
+

Vector Store Details

+
+ + +
+
+
+
{mainContent}
+
{sidebar}
+
+ + {/* Edit Vector Store Modal */} + {showEditModal && ( +
+
+
+

Edit Vector Store

+ +
+
+ +
+
+
+ )} + ); } diff --git a/src/llama_stack_ui/components/vector-stores/vector-store-editor.tsx b/src/llama_stack_ui/components/vector-stores/vector-store-editor.tsx new file mode 100644 index 000000000..719a2a9fd --- /dev/null +++ b/src/llama_stack_ui/components/vector-stores/vector-store-editor.tsx @@ -0,0 +1,235 @@ +"use client"; + +import { useState, useEffect } from "react"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Card, CardContent } from "@/components/ui/card"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { useAuthClient } from "@/hooks/use-auth-client"; +import type { Model } from "llama-stack-client/resources/models"; + +export interface VectorStoreFormData { + name: string; + embedding_model?: string; + embedding_dimension?: number; + provider_id?: string; +} + +interface VectorStoreEditorProps { + onSave: (formData: VectorStoreFormData) => Promise; + onCancel: () => void; + error?: string | null; + initialData?: VectorStoreFormData; + showSuccessState?: boolean; + isEditing?: boolean; +} + +export function VectorStoreEditor({ + onSave, + onCancel, + error, + initialData, + showSuccessState, + isEditing = false, +}: VectorStoreEditorProps) { + const client = useAuthClient(); + const [formData, setFormData] = useState( + initialData || { + name: "", + embedding_model: "", + embedding_dimension: 768, + provider_id: "", + } + ); + const [loading, setLoading] = useState(false); + const [models, setModels] = useState([]); + const [modelsLoading, setModelsLoading] = useState(true); + const [modelsError, setModelsError] = useState(null); + + const embeddingModels = models.filter( + model => model.custom_metadata?.model_type === "embedding" + ); + + useEffect(() => { + const fetchModels = async () => { + try { + setModelsLoading(true); + setModelsError(null); + const 
modelList = await client.models.list(); + setModels(modelList); + + // Set default embedding model if available + const embeddingModelsList = modelList.filter(model => { + return model.custom_metadata?.model_type === "embedding"; + }); + if (embeddingModelsList.length > 0 && !formData.embedding_model) { + setFormData(prev => ({ + ...prev, + embedding_model: embeddingModelsList[0].id, + })); + } + } catch (err) { + console.error("Failed to load models:", err); + setModelsError( + err instanceof Error ? err.message : "Failed to load models" + ); + } finally { + setModelsLoading(false); + } + }; + + fetchModels(); + }, [client]); + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + setLoading(true); + + try { + await onSave(formData); + } finally { + setLoading(false); + } + }; + + return ( + + +
+
+ + setFormData({ ...formData, name: e.target.value })} + placeholder="Enter vector store name" + required + /> +
+ +
+ + {modelsLoading ? ( +
+ Loading models... ({models.length} loaded) +
+ ) : modelsError ? ( +
+ Error: {modelsError} +
+ ) : embeddingModels.length === 0 ? ( +
+ No embedding models available ({models.length} total models) +
+ ) : ( + + )} + {formData.embedding_model && ( +

+ Dimension:{" "} + {embeddingModels.find(m => m.id === formData.embedding_model) + ?.custom_metadata?.embedding_dimension || "Unknown"} +

+ )} +
+ +
+ + + setFormData({ + ...formData, + embedding_dimension: parseInt(e.target.value) || 768, + }) + } + placeholder="768" + /> +
+ +
+ + + setFormData({ ...formData, provider_id: e.target.value }) + } + placeholder="e.g., faiss, chroma, sqlite" + disabled={isEditing} + /> + {isEditing && ( +

+ Provider ID cannot be changed after vector store creation +

+ )} +
+ + {error && ( +
+ {error} +
+ )} + +
+ {showSuccessState ? ( + + ) : ( + <> + + + + )} +
+
+
+
+ ); +} diff --git a/src/llama_stack_ui/lib/contents-api.ts b/src/llama_stack_ui/lib/contents-api.ts index f4920f3db..35456faff 100644 --- a/src/llama_stack_ui/lib/contents-api.ts +++ b/src/llama_stack_ui/lib/contents-api.ts @@ -34,9 +34,35 @@ export class ContentsAPI { async getFileContents( vectorStoreId: string, - fileId: string + fileId: string, + includeEmbeddings: boolean = true, + includeMetadata: boolean = true ): Promise { - return this.client.vectorStores.files.content(vectorStoreId, fileId); + try { + // Use query parameters to pass embeddings and metadata flags (OpenAI-compatible pattern) + const extraQuery: Record = {}; + if (includeEmbeddings) { + extraQuery.include_embeddings = true; + } + if (includeMetadata) { + extraQuery.include_metadata = true; + } + + const result = await this.client.vectorStores.files.content( + vectorStoreId, + fileId, + { + query: { + include_embeddings: includeEmbeddings, + include_metadata: includeMetadata, + }, + } + ); + return result; + } catch (error) { + console.error("ContentsAPI.getFileContents error:", error); + throw error; + } } async getContent( @@ -70,11 +96,15 @@ export class ContentsAPI { order?: string; after?: string; before?: string; + includeEmbeddings?: boolean; + includeMetadata?: boolean; } ): Promise { - const fileContents = await this.client.vectorStores.files.content( + const fileContents = await this.getFileContents( vectorStoreId, - fileId + fileId, + options?.includeEmbeddings ?? true, + options?.includeMetadata ?? 
true ); const contentItems: VectorStoreContentItem[] = []; @@ -82,7 +112,7 @@ export class ContentsAPI { const rawContent = content as Record; // Extract actual fields from the API response - const embedding = rawContent.embedding || undefined; + const embedding = rawContent.embedding as number[] | undefined; const created_timestamp = rawContent.created_timestamp || rawContent.created_at || diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 20f9d2978..1043d4903 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -11,6 +11,7 @@ import pytest from llama_stack_client import BadRequestError from openai import BadRequestError as OpenAIBadRequestError +from llama_stack.apis.files import ExpiresAfter from llama_stack.apis.vector_io import Chunk from llama_stack.core.library_client import LlamaStackAsLibraryClient from llama_stack.log import get_logger @@ -1604,3 +1605,97 @@ def test_openai_vector_store_embedding_config_from_metadata( assert "metadata_config_store" in store_names assert "consistent_config_store" in store_names + + +@vector_provider_wrapper +def test_openai_vector_store_file_contents_with_extra_query( + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id +): + """Test that vector store file contents endpoint supports extra_query parameter.""" + skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) + compat_client = compat_client_with_empty_stores + + # Create a vector store + vector_store = compat_client.vector_stores.create( + name="test_extra_query_store", + extra_body={ + "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, + }, + ) + + # Create and attach a file + test_content = b"This is test content for extra_query validation." 
+ with BytesIO(test_content) as file_buffer: + file_buffer.name = "test_extra_query.txt" + file = compat_client.files.create( + file=file_buffer, + purpose="assistants", + expires_after=ExpiresAfter(anchor="created_at", seconds=86400), + ) + + file_attach_response = compat_client.vector_stores.files.create( + vector_store_id=vector_store.id, + file_id=file.id, + extra_body={"embedding_model": embedding_model_id}, + ) + assert file_attach_response.status == "completed" + + # Wait for processing + time.sleep(2) + + # Test that extra_query parameter is accepted and processed + content_with_extra_query = compat_client.vector_stores.files.content( + vector_store_id=vector_store.id, + file_id=file.id, + extra_query={"include_embeddings": True, "include_metadata": True}, + ) + + # Test without extra_query for comparison + content_without_extra_query = compat_client.vector_stores.files.content( + vector_store_id=vector_store.id, + file_id=file.id, + ) + + # Validate that both calls succeed + assert content_with_extra_query is not None + assert content_without_extra_query is not None + assert len(content_with_extra_query.data) > 0 + assert len(content_without_extra_query.data) > 0 + + # Validate that extra_query parameter is processed correctly + # Both should have the embedding/metadata fields available (may be None based on flags) + first_chunk_with_flags = content_with_extra_query.data[0] + first_chunk_without_flags = content_without_extra_query.data[0] + + # The key validation: extra_query fields are present in the response + # Handle both dict and object responses (different clients may return different formats) + def has_field(obj, field): + if isinstance(obj, dict): + return field in obj + else: + return hasattr(obj, field) + + # Validate that all expected fields are present in both responses + expected_fields = ["embedding", "chunk_metadata", "metadata", "text"] + for field in expected_fields: + assert has_field(first_chunk_with_flags, field), f"Field '{field}' 
missing from response with extra_query" + assert has_field(first_chunk_without_flags, field), f"Field '{field}' missing from response without extra_query" + + # Validate content is the same + def get_field(obj, field): + if isinstance(obj, dict): + return obj[field] + else: + return getattr(obj, field) + + assert get_field(first_chunk_with_flags, "text") == test_content.decode("utf-8") + assert get_field(first_chunk_without_flags, "text") == test_content.decode("utf-8") + + with_flags_embedding = get_field(first_chunk_with_flags, "embedding") + without_flags_embedding = get_field(first_chunk_without_flags, "embedding") + + # Validate that embeddings are included when requested and excluded when not requested + assert with_flags_embedding is not None, "Embeddings should be included when include_embeddings=True" + assert len(with_flags_embedding) > 0, "Embedding should be a non-empty list" + assert without_flags_embedding is None, "Embeddings should not be included when include_embeddings=False" diff --git a/tests/unit/core/routers/test_vector_io.py b/tests/unit/core/routers/test_vector_io.py index dd3246cb3..f9bd84a37 100644 --- a/tests/unit/core/routers/test_vector_io.py +++ b/tests/unit/core/routers/test_vector_io.py @@ -55,3 +55,65 @@ async def test_create_vector_stores_multiple_providers_missing_provider_id_error with pytest.raises(ValueError, match="Multiple vector_io providers available"): await router.openai_create_vector_store(request) + + +async def test_update_vector_store_provider_id_change_fails(): + """Test that updating a vector store with a different provider_id fails with clear error.""" + mock_routing_table = Mock() + + # Mock an existing vector store with provider_id "faiss" + mock_existing_store = Mock() + mock_existing_store.provider_id = "inline::faiss" + mock_existing_store.identifier = "vs_123" + + mock_routing_table.get_object_by_identifier = AsyncMock(return_value=mock_existing_store) + mock_routing_table.get_provider_impl = AsyncMock( + 
return_value=Mock(openai_update_vector_store=AsyncMock(return_value=Mock(id="vs_123"))) + ) + + router = VectorIORouter(mock_routing_table) + + # Try to update with different provider_id in metadata - this should fail + with pytest.raises(ValueError, match="provider_id cannot be changed after vector store creation"): + await router.openai_update_vector_store( + vector_store_id="vs_123", + name="updated_name", + metadata={"provider_id": "inline::sqlite"}, # Different provider_id + ) + + # Verify the existing store was looked up to check provider_id + mock_routing_table.get_object_by_identifier.assert_called_once_with("vector_store", "vs_123") + + # Provider should not be called since validation failed + mock_routing_table.get_provider_impl.assert_not_called() + + +async def test_update_vector_store_same_provider_id_succeeds(): + """Test that updating a vector store with the same provider_id succeeds.""" + mock_routing_table = Mock() + + # Mock an existing vector store with provider_id "faiss" + mock_existing_store = Mock() + mock_existing_store.provider_id = "inline::faiss" + mock_existing_store.identifier = "vs_123" + + mock_routing_table.get_object_by_identifier = AsyncMock(return_value=mock_existing_store) + mock_routing_table.get_provider_impl = AsyncMock( + return_value=Mock(openai_update_vector_store=AsyncMock(return_value=Mock(id="vs_123"))) + ) + + router = VectorIORouter(mock_routing_table) + + # Update with same provider_id should succeed + await router.openai_update_vector_store( + vector_store_id="vs_123", + name="updated_name", + metadata={"provider_id": "inline::faiss"}, # Same provider_id + ) + + # Verify the provider update method was called + mock_routing_table.get_provider_impl.assert_called_once_with("vs_123") + provider = await mock_routing_table.get_provider_impl("vs_123") + provider.openai_update_vector_store.assert_called_once_with( + vector_store_id="vs_123", name="updated_name", expires_after=None, metadata={"provider_id": "inline::faiss"} + 
) diff --git a/tests/unit/server/test_sse.py b/tests/unit/server/test_sse.py index f36c8c181..0303a6ded 100644 --- a/tests/unit/server/test_sse.py +++ b/tests/unit/server/test_sse.py @@ -104,12 +104,18 @@ async def test_paginated_response_url_setting(): route_handler = create_dynamic_typed_route(mock_api_method, "get", "/test/route") - # Mock minimal request + # Mock minimal request with proper state object request = MagicMock() request.scope = {"user_attributes": {}, "principal": ""} request.headers = {} request.body = AsyncMock(return_value=b"") + # Create a simple state object without auto-generating attributes + class MockState: + pass + + request.state = MockState() + result = await route_handler(request) assert isinstance(result, PaginatedResponse) From 94e977c257f70296c73a913dd4218d9119558632 Mon Sep 17 00:00:00 2001 From: Ken Dreyer Date: Wed, 12 Nov 2025 13:04:56 -0500 Subject: [PATCH 14/15] fix(docs): link to test replay-record docs for discoverability (#4134) Help users find the comprehensive integration testing docs by linking to the record-replay documentation. This clarifies that the technical README complements the main docs. --- tests/integration/recordings/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/integration/recordings/README.md b/tests/integration/recordings/README.md index 621a07562..bdf4f532f 100644 --- a/tests/integration/recordings/README.md +++ b/tests/integration/recordings/README.md @@ -2,6 +2,10 @@ This directory contains recorded inference API responses used for deterministic testing without requiring live API access. +For more information, see the +[docs](https://llamastack.github.io/docs/contributing/testing/record-replay). +This README provides more technical information. 
+ ## Structure - `responses/` - JSON files containing request/response pairs for inference operations From 356f37b1bae1e98f23ed8f2dd224973b249827ec Mon Sep 17 00:00:00 2001 From: Derek Higgins Date: Wed, 12 Nov 2025 18:13:26 +0000 Subject: [PATCH 15/15] docs: clarify model identification uses provider_model_id not model_id (#4128) Updated documentation to accurately reflect current behavior where models are identified as provider_id/provider_model_id in the system. Changes: o Clarify that model_id is for configuration purposes only o Explain models are accessed as provider_id/provider_model_id o Remove outdated aliasing example that suggested model_id could be used as a custom identifier This corrects the documentation which previously suggested model_id could be used to create friendly aliases, which is not how the code actually works. Signed-off-by: Derek Higgins --- docs/docs/distributions/configuration.mdx | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/docs/distributions/configuration.mdx b/docs/docs/distributions/configuration.mdx index ff50c406a..46ecfa475 100644 --- a/docs/docs/distributions/configuration.mdx +++ b/docs/docs/distributions/configuration.mdx @@ -221,7 +221,15 @@ models: ``` A Model is an instance of a "Resource" (see [Concepts](../concepts/)) and is associated with a specific inference provider (in this case, the provider with identifier `ollama`). This is an instance of a "pre-registered" model. While we always encourage the clients to register models before using them, some Stack servers may come up a list of "already known and available" models. -What's with the `provider_model_id` field? This is an identifier for the model inside the provider's model catalog. Contrast it with `model_id` which is the identifier for the same model for Llama Stack's purposes. For example, you may want to name "llama3.2:vision-11b" as "image_captioning_model" when you use it in your Stack interactions. 
When omitted, the server will set `provider_model_id` to be the same as `model_id`. +What's with the `provider_model_id` field? This is an identifier for the model inside the provider's model catalog. The `model_id` field is provided for configuration purposes but is not used as part of the model identifier. + +**Important:** Models are identified as `provider_id/provider_model_id` in the system and when making API calls. When `provider_model_id` is omitted, the server will set it to be the same as `model_id`. + +Examples: +- Config: `model_id: llama3.2`, `provider_id: ollama`, `provider_model_id: null` + → Access as: `ollama/llama3.2` +- Config: `model_id: my-llama`, `provider_id: vllm-inference`, `provider_model_id: llama-3-2-3b` + → Access as: `vllm-inference/llama-3-2-3b` (the `model_id` is not used in the identifier) If you need to conditionally register a model in the configuration, such as only when specific environment variable(s) are set, this can be accomplished by utilizing a special `__disabled__` string as the default value of an environment variable substitution, as shown below: