diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml index d8159be62..9f3ef15b5 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -963,7 +963,7 @@ paths: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, - returns only non-deprecated v1 routes. + returns all non-deprecated routes. required: false schema: type: string @@ -998,39 +998,6 @@ paths: description: List models using the OpenAI API. parameters: [] deprecated: false - post: - responses: - '200': - description: A Model. - content: - application/json: - schema: - $ref: '#/components/schemas/Model' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Models - summary: Register model. - description: >- - Register model. - - Register a model. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterModelRequest' - required: true - deprecated: false /v1/models/{model_id}: get: responses: @@ -1065,36 +1032,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Models - summary: Unregister model. - description: >- - Unregister model. - - Unregister a model. - parameters: - - name: model_id - in: path - description: >- - The identifier of the model to unregister. - required: true - schema: - type: string - deprecated: false /v1/moderations: post: responses: @@ -1725,32 +1662,6 @@ paths: description: List all scoring functions. parameters: [] deprecated: false - post: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ScoringFunctions - summary: Register a scoring function. - description: Register a scoring function. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterScoringFunctionRequest' - required: true - deprecated: false /v1/scoring-functions/{scoring_fn_id}: get: responses: @@ -1782,33 +1693,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ScoringFunctions - summary: Unregister a scoring function. - description: Unregister a scoring function. - parameters: - - name: scoring_fn_id - in: path - description: >- - The ID of the scoring function to unregister. - required: true - schema: - type: string - deprecated: false /v1/scoring/score: post: responses: @@ -1897,36 +1781,6 @@ paths: description: List all shields. 
parameters: [] deprecated: false - post: - responses: - '200': - description: A Shield. - content: - application/json: - schema: - $ref: '#/components/schemas/Shield' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Shields - summary: Register a shield. - description: Register a shield. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterShieldRequest' - required: true - deprecated: false /v1/shields/{identifier}: get: responses: @@ -1958,33 +1812,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Shields - summary: Unregister a shield. - description: Unregister a shield. - parameters: - - name: identifier - in: path - description: >- - The identifier of the shield to unregister. - required: true - schema: - type: string - deprecated: false /v1/tool-runtime/invoke: post: responses: @@ -2080,32 +1907,6 @@ paths: description: List tool groups with optional provider. parameters: [] deprecated: false - post: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ToolGroups - summary: Register a tool group. - description: Register a tool group. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterToolGroupRequest' - required: true - deprecated: false /v1/toolgroups/{toolgroup_id}: get: responses: @@ -2137,32 +1938,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ToolGroups - summary: Unregister a tool group. - description: Unregister a tool group. - parameters: - - name: toolgroup_id - in: path - description: The ID of the tool group to unregister. - required: true - schema: - type: string - deprecated: false /v1/tools: get: responses: @@ -2916,11 +2691,11 @@ paths: responses: '200': description: >- - A list of InterleavedContent representing the file contents. + A VectorStoreFileContentResponse representing the file contents. 
content: application/json: schema: - $ref: '#/components/schemas/VectorStoreFileContentsResponse' + $ref: '#/components/schemas/VectorStoreFileContentResponse' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -3171,7 +2946,7 @@ paths: schema: $ref: '#/components/schemas/RegisterDatasetRequest' required: true - deprecated: false + deprecated: true /v1beta/datasets/{dataset_id}: get: responses: @@ -3228,7 +3003,7 @@ paths: required: true schema: type: string - deprecated: false + deprecated: true /v1alpha/eval/benchmarks: get: responses: @@ -3279,7 +3054,7 @@ paths: schema: $ref: '#/components/schemas/RegisterBenchmarkRequest' required: true - deprecated: false + deprecated: true /v1alpha/eval/benchmarks/{benchmark_id}: get: responses: @@ -3336,7 +3111,7 @@ paths: required: true schema: type: string - deprecated: false + deprecated: true /v1alpha/eval/benchmarks/{benchmark_id}/evaluations: post: responses: @@ -6280,46 +6055,6 @@ components: required: - data title: OpenAIListModelsResponse - ModelType: - type: string - enum: - - llm - - embedding - - rerank - title: ModelType - description: >- - Enumeration of supported model types in Llama Stack. - RegisterModelRequest: - type: object - properties: - model_id: - type: string - description: The identifier of the model to register. - provider_model_id: - type: string - description: >- - The identifier of the model in the provider. - provider_id: - type: string - description: The identifier of the provider. - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: Any additional metadata for this model. - model_type: - $ref: '#/components/schemas/ModelType' - description: The type of model to register. - additionalProperties: false - required: - - model_id - title: RegisterModelRequest Model: type: object properties: @@ -6377,6 +6112,15 @@ components: title: Model description: >- A model resource representing an AI model registered in Llama Stack. + ModelType: + type: string + enum: + - llm + - embedding + - rerank + title: ModelType + description: >- + Enumeration of supported model types in Llama Stack. RunModerationRequest: type: object properties: @@ -6882,6 +6626,11 @@ components: type: string description: >- (Optional) System message inserted into the model's context + max_tool_calls: + type: integer + description: >- + (Optional) Max number of total calls to built-in tools that can be processed + in a response input: type: array items: @@ -7240,6 +6989,11 @@ components: (Optional) Additional fields to include in the response. max_infer_iters: type: integer + max_tool_calls: + type: integer + description: >- + (Optional) Max number of total calls to built-in tools that can be processed + in a response. 
additionalProperties: false required: - input @@ -7321,6 +7075,11 @@ components: type: string description: >- (Optional) System message inserted into the model's context + max_tool_calls: + type: integer + description: >- + (Optional) Max number of total calls to built-in tools that can be processed + in a response additionalProperties: false required: - created_at @@ -9115,61 +8874,6 @@ components: required: - data title: ListScoringFunctionsResponse - ParamType: - oneOf: - - $ref: '#/components/schemas/StringType' - - $ref: '#/components/schemas/NumberType' - - $ref: '#/components/schemas/BooleanType' - - $ref: '#/components/schemas/ArrayType' - - $ref: '#/components/schemas/ObjectType' - - $ref: '#/components/schemas/JsonType' - - $ref: '#/components/schemas/UnionType' - - $ref: '#/components/schemas/ChatCompletionInputType' - - $ref: '#/components/schemas/CompletionInputType' - discriminator: - propertyName: type - mapping: - string: '#/components/schemas/StringType' - number: '#/components/schemas/NumberType' - boolean: '#/components/schemas/BooleanType' - array: '#/components/schemas/ArrayType' - object: '#/components/schemas/ObjectType' - json: '#/components/schemas/JsonType' - union: '#/components/schemas/UnionType' - chat_completion_input: '#/components/schemas/ChatCompletionInputType' - completion_input: '#/components/schemas/CompletionInputType' - RegisterScoringFunctionRequest: - type: object - properties: - scoring_fn_id: - type: string - description: >- - The ID of the scoring function to register. - description: - type: string - description: The description of the scoring function. - return_type: - $ref: '#/components/schemas/ParamType' - description: The return type of the scoring function. - provider_scoring_fn_id: - type: string - description: >- - The ID of the provider scoring function to use for the scoring function. - provider_id: - type: string - description: >- - The ID of the provider to use for the scoring function. - params: - $ref: '#/components/schemas/ScoringFnParams' - description: >- - The parameters for the scoring function for benchmark eval, these can - be overridden for app eval. - additionalProperties: false - required: - - scoring_fn_id - - description - - return_type - title: RegisterScoringFunctionRequest ScoreRequest: type: object properties: @@ -9345,35 +9049,6 @@ components: required: - data title: ListShieldsResponse - RegisterShieldRequest: - type: object - properties: - shield_id: - type: string - description: >- - The identifier of the shield to register. - provider_shield_id: - type: string - description: >- - The identifier of the shield in the provider. - provider_id: - type: string - description: The identifier of the provider. - params: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: The parameters of the shield. - additionalProperties: false - required: - - shield_id - title: RegisterShieldRequest InvokeToolRequest: type: object properties: @@ -9634,37 +9309,6 @@ components: title: ListToolGroupsResponse description: >- Response containing a list of tool groups. - RegisterToolGroupRequest: - type: object - properties: - toolgroup_id: - type: string - description: The ID of the tool group to register. - provider_id: - type: string - description: >- - The ID of the provider to use for the tool group. - mcp_endpoint: - $ref: '#/components/schemas/URL' - description: >- - The MCP endpoint to use for the tool group. 
- args: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - A dictionary of arguments to pass to the tool group. - additionalProperties: false - required: - - toolgroup_id - - provider_id - title: RegisterToolGroupRequest Chunk: type: object properties: @@ -10465,41 +10109,35 @@ components: title: VectorStoreContent description: >- Content item from a vector store file or search result. - VectorStoreFileContentsResponse: + VectorStoreFileContentResponse: type: object properties: - file_id: + object: type: string - description: Unique identifier for the file - filename: - type: string - description: Name of the file - attributes: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object + const: vector_store.file_content.page + default: vector_store.file_content.page description: >- - Key-value attributes associated with the file - content: + The object type, which is always `vector_store.file_content.page` + data: type: array items: $ref: '#/components/schemas/VectorStoreContent' - description: List of content items from the file + description: Parsed content of the file + has_more: + type: boolean + description: >- + Indicates if there are more content pages to fetch + next_page: + type: string + description: The token for the next page, if any additionalProperties: false required: - - file_id - - filename - - attributes - - content - title: VectorStoreFileContentsResponse + - object + - data + - has_more + title: VectorStoreFileContentResponse description: >- - Response from retrieving the contents of a vector store file. + Represents the parsed content of a vector store file. OpenaiSearchVectorStoreRequest: type: object properties: @@ -10816,68 +10454,6 @@ components: - data title: ListDatasetsResponse description: Response from listing datasets. - DataSource: - oneOf: - - $ref: '#/components/schemas/URIDataSource' - - $ref: '#/components/schemas/RowsDataSource' - discriminator: - propertyName: type - mapping: - uri: '#/components/schemas/URIDataSource' - rows: '#/components/schemas/RowsDataSource' - RegisterDatasetRequest: - type: object - properties: - purpose: - type: string - enum: - - post-training/messages - - eval/question-answer - - eval/messages-answer - description: >- - The purpose of the dataset. One of: - "post-training/messages": The dataset - contains a messages column with list of messages for post-training. { - "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant", - "content": "Hello, world!"}, ] } - "eval/question-answer": The dataset - contains a question column and an answer column for evaluation. { "question": - "What is the capital of France?", "answer": "Paris" } - "eval/messages-answer": - The dataset contains a messages column with list of messages and an answer - column for evaluation. { "messages": [ {"role": "user", "content": "Hello, - my name is John Doe."}, {"role": "assistant", "content": "Hello, John - Doe. How can I help you today?"}, {"role": "user", "content": "What's - my name?"}, ], "answer": "John Doe" } - source: - $ref: '#/components/schemas/DataSource' - description: >- - The data source of the dataset. Ensure that the data source schema is - compatible with the purpose of the dataset. 
Examples: - { "type": "uri", - "uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": - "lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}" - } - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train" - } - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content": - "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ] - } ] } - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - The metadata for the dataset. - E.g. {"description": "My dataset"}. - dataset_id: - type: string - description: >- - The ID of the dataset. If not provided, an ID will be generated. - additionalProperties: false - required: - - purpose - - source - title: RegisterDatasetRequest Benchmark: type: object properties: @@ -10945,47 +10521,6 @@ components: required: - data title: ListBenchmarksResponse - RegisterBenchmarkRequest: - type: object - properties: - benchmark_id: - type: string - description: The ID of the benchmark to register. - dataset_id: - type: string - description: >- - The ID of the dataset to use for the benchmark. - scoring_functions: - type: array - items: - type: string - description: >- - The scoring functions to use for the benchmark. - provider_benchmark_id: - type: string - description: >- - The ID of the provider benchmark to use for the benchmark. - provider_id: - type: string - description: >- - The ID of the provider to use for the benchmark. - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: The metadata to use for the benchmark. - additionalProperties: false - required: - - benchmark_id - - dataset_id - - scoring_functions - title: RegisterBenchmarkRequest BenchmarkConfig: type: object properties: @@ -11847,6 +11382,109 @@ components: - hyperparam_search_config - logger_config title: SupervisedFineTuneRequest + DataSource: + oneOf: + - $ref: '#/components/schemas/URIDataSource' + - $ref: '#/components/schemas/RowsDataSource' + discriminator: + propertyName: type + mapping: + uri: '#/components/schemas/URIDataSource' + rows: '#/components/schemas/RowsDataSource' + RegisterDatasetRequest: + type: object + properties: + purpose: + type: string + enum: + - post-training/messages + - eval/question-answer + - eval/messages-answer + description: >- + The purpose of the dataset. One of: - "post-training/messages": The dataset + contains a messages column with list of messages for post-training. { + "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant", + "content": "Hello, world!"}, ] } - "eval/question-answer": The dataset + contains a question column and an answer column for evaluation. { "question": + "What is the capital of France?", "answer": "Paris" } - "eval/messages-answer": + The dataset contains a messages column with list of messages and an answer + column for evaluation. { "messages": [ {"role": "user", "content": "Hello, + my name is John Doe."}, {"role": "assistant", "content": "Hello, John + Doe. How can I help you today?"}, {"role": "user", "content": "What's + my name?"}, ], "answer": "John Doe" } + source: + $ref: '#/components/schemas/DataSource' + description: >- + The data source of the dataset. Ensure that the data source schema is + compatible with the purpose of the dataset. 
Examples: - { "type": "uri", + "uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": + "lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}" + } - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train" + } - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content": + "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ] + } ] } + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + The metadata for the dataset. - E.g. {"description": "My dataset"}. + dataset_id: + type: string + description: >- + The ID of the dataset. If not provided, an ID will be generated. + additionalProperties: false + required: + - purpose + - source + title: RegisterDatasetRequest + RegisterBenchmarkRequest: + type: object + properties: + benchmark_id: + type: string + description: The ID of the benchmark to register. + dataset_id: + type: string + description: >- + The ID of the dataset to use for the benchmark. + scoring_functions: + type: array + items: + type: string + description: >- + The scoring functions to use for the benchmark. + provider_benchmark_id: + type: string + description: >- + The ID of the provider benchmark to use for the benchmark. + provider_id: + type: string + description: >- + The ID of the provider to use for the benchmark. + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: The metadata to use for the benchmark. + additionalProperties: false + required: + - benchmark_id + - dataset_id + - scoring_functions + title: RegisterBenchmarkRequest responses: BadRequest400: description: The request was invalid or malformed diff --git a/docs/docs/deploying/kubernetes_deployment.mdx b/docs/docs/deploying/kubernetes_deployment.mdx index 8ed1e2756..48d08f0db 100644 --- a/docs/docs/deploying/kubernetes_deployment.mdx +++ b/docs/docs/deploying/kubernetes_deployment.mdx @@ -10,7 +10,7 @@ import TabItem from '@theme/TabItem'; # Kubernetes Deployment Guide -Deploy Llama Stack and vLLM servers in a Kubernetes cluster instead of running them locally. This guide covers both local development with Kind and production deployment on AWS EKS. +Deploy Llama Stack and vLLM servers in a Kubernetes cluster instead of running them locally. This guide covers deployment using the Kubernetes operator to manage the Llama Stack server with Kind. The vLLM inference server is deployed manually. 
## Prerequisites @@ -110,115 +110,176 @@ spec: EOF ``` -### Step 3: Configure Llama Stack +### Step 3: Install Kubernetes Operator -Update your run configuration: - -```yaml -providers: - inference: - - provider_id: vllm - provider_type: remote::vllm - config: - url: http://vllm-server.default.svc.cluster.local:8000/v1 - max_tokens: 4096 - api_token: fake -``` - -Build container image: +Install the Llama Stack Kubernetes operator to manage Llama Stack deployments: ```bash -tmp_dir=$(mktemp -d) && cat >$tmp_dir/Containerfile.llama-stack-run-k8s <-service`): + +```bash +# List services to find the service name +kubectl get services | grep llamastack + +# Port forward and test (replace SERVICE_NAME with the actual service name) +kubectl port-forward service/llamastack-vllm-service 8321:8321 +``` + +In another terminal, test the deployment: + +```bash +llama-stack-client --endpoint http://localhost:8321 inference chat-completion --message "hello, what model are you?" ``` ## Troubleshooting -**Check pod status:** +### vLLM Server Issues + +**Check vLLM pod status:** ```bash kubectl get pods -l app.kubernetes.io/name=vllm kubectl logs -l app.kubernetes.io/name=vllm ``` -**Test service connectivity:** +**Test vLLM service connectivity:** ```bash kubectl run -it --rm debug --image=curlimages/curl --restart=Never -- curl http://vllm-server:8000/v1/models ``` +### Llama Stack Server Issues + +**Check LlamaStackDistribution status:** +```bash +# Get detailed status +kubectl describe llamastackdistribution llamastack-vllm + +# Check for events +kubectl get events --sort-by='.lastTimestamp' | grep llamastack-vllm +``` + +**Check operator-managed pods:** +```bash +# List all pods managed by the operator +kubectl get pods -l app.kubernetes.io/name=llama-stack + +# Check pod logs (replace POD_NAME with actual pod name) +kubectl logs -l app.kubernetes.io/name=llama-stack +``` + +**Check operator status:** +```bash +# Verify the operator is running +kubectl get pods -n llama-stack-operator-system + +# Check operator logs if issues persist +kubectl logs -n llama-stack-operator-system -l control-plane=controller-manager +``` + +**Verify service connectivity:** +```bash +# Get the service endpoint +kubectl get svc llamastack-vllm-service + +# Test connectivity from within the cluster +kubectl run -it --rm debug --image=curlimages/curl --restart=Never -- curl http://llamastack-vllm-service:8321/health +``` + ## Related Resources - **[Deployment Overview](/docs/deploying/)** - Overview of deployment options - **[Distributions](/docs/distributions)** - Understanding Llama Stack distributions - **[Configuration](/docs/distributions/configuration)** - Detailed configuration options +- **[LlamaStack Operator](https://github.com/llamastack/llama-stack-k8s-operator)** - Overview of llama-stack kubernetes operator +- **[LlamaStackDistribution](https://github.com/llamastack/llama-stack-k8s-operator/blob/main/docs/api-overview.md)** - API Spec of the llama-stack operator Custom Resource. diff --git a/docs/docs/distributions/remote_hosted_distro/oci.md b/docs/docs/distributions/remote_hosted_distro/oci.md new file mode 100644 index 000000000..b13cf5f73 --- /dev/null +++ b/docs/docs/distributions/remote_hosted_distro/oci.md @@ -0,0 +1,143 @@ +--- +orphan: true +--- + +# OCI Distribution + +The `llamastack/distribution-oci` distribution consists of the following provider configurations. 
+ +| API | Provider(s) | +|-----|-------------| +| agents | `inline::meta-reference` | +| datasetio | `remote::huggingface`, `inline::localfs` | +| eval | `inline::meta-reference` | +| files | `inline::localfs` | +| inference | `remote::oci` | +| safety | `inline::llama-guard` | +| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | +| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` | +| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | + + +### Environment Variables + +The following environment variables can be configured: + +- `OCI_AUTH_TYPE`: OCI authentication type (instance_principal or config_file) (default: `instance_principal`) +- `OCI_REGION`: OCI region (e.g., us-ashburn-1, us-chicago-1, us-phoenix-1, eu-frankfurt-1) (default: ``) +- `OCI_COMPARTMENT_OCID`: OCI compartment ID for the Generative AI service (default: ``) +- `OCI_CONFIG_FILE_PATH`: OCI config file path (required if OCI_AUTH_TYPE is config_file) (default: `~/.oci/config`) +- `OCI_CLI_PROFILE`: OCI CLI profile name to use from config file (default: `DEFAULT`) + + +## Prerequisites +### Oracle Cloud Infrastructure Setup + +Before using the OCI Generative AI distribution, ensure you have: + +1. **Oracle Cloud Infrastructure Account**: Sign up at [Oracle Cloud Infrastructure](https://cloud.oracle.com/) +2. **Generative AI Service Access**: Enable the Generative AI service in your OCI tenancy +3. **Compartment**: Create or identify a compartment where you'll deploy Generative AI models +4. **Authentication**: Configure authentication using either: + - **Instance Principal** (recommended for cloud-hosted deployments) + - **API Key** (for on-premises or development environments) + +### Authentication Methods + +#### Instance Principal Authentication (Recommended) +Instance Principal authentication allows OCI resources to authenticate using the identity of the compute instance they're running on. This is the most secure method for production deployments. + +Requirements: +- Instance must be running in an Oracle Cloud Infrastructure compartment +- Instance must have appropriate IAM policies to access Generative AI services + +#### API Key Authentication +For development or on-premises deployments, follow [this doc](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/apisigningkey.htm) to learn how to create your API signing key for your config file. + +### Required IAM Policies + +Ensure your OCI user or instance has the following policy statements: + +``` +Allow group to use generative-ai-inference-endpoints in compartment +Allow group to manage generative-ai-inference-endpoints in compartment +``` + +## Supported Services + +### Inference: OCI Generative AI +Oracle Cloud Infrastructure Generative AI provides access to high-performance AI models through OCI's Platform-as-a-Service offering. The service supports: + +- **Chat Completions**: Conversational AI with context awareness +- **Text Generation**: Complete prompts and generate text content + +#### Available Models +Common OCI Generative AI models include access to Meta, Cohere, OpenAI, Grok, and more models. 
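+
+Once the distribution is running (see "Running Llama Stack with OCI" below), one way to confirm which model identifiers were actually registered from your tenancy is to list them through the client CLI. This is a sketch, not part of the distribution itself: the endpoint and port assume the default local setup used by the run command later in this page.
+
+```bash
+# List the models the running stack has registered, including those served by the OCI provider.
+# Assumes the server is reachable on localhost:8321; adjust the endpoint for your deployment.
+llama-stack-client --endpoint http://localhost:8321 models list
+```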
+ +### Safety: Llama Guard +For content safety and moderation, this distribution uses Meta's LlamaGuard model through the OCI Generative AI service to provide: +- Content filtering and moderation +- Policy compliance checking +- Harmful content detection + +### Vector Storage: Multiple Options +The distribution supports several vector storage providers: +- **FAISS**: Local in-memory vector search +- **ChromaDB**: Distributed vector database +- **PGVector**: PostgreSQL with vector extensions + +### Additional Services +- **Dataset I/O**: Local filesystem and Hugging Face integration +- **Tool Runtime**: Web search (Brave, Tavily) and RAG capabilities +- **Evaluation**: Meta reference evaluation framework + +## Running Llama Stack with OCI + +You can run the OCI distribution via Docker or local virtual environment. + +### Via venv + +If you've set up your local development environment, you can also build the image using your local virtual environment. + +```bash +OCI_AUTH=$OCI_AUTH_TYPE OCI_REGION=$OCI_REGION OCI_COMPARTMENT_OCID=$OCI_COMPARTMENT_OCID llama stack run --port 8321 oci +``` + +### Configuration Examples + +#### Using Instance Principal (Recommended for Production) +```bash +export OCI_AUTH_TYPE=instance_principal +export OCI_REGION=us-chicago-1 +export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1.. +``` + +#### Using API Key Authentication (Development) +```bash +export OCI_AUTH_TYPE=config_file +export OCI_CONFIG_FILE_PATH=~/.oci/config +export OCI_CLI_PROFILE=DEFAULT +export OCI_REGION=us-chicago-1 +export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..your-compartment-id +``` + +## Regional Endpoints + +OCI Generative AI is available in multiple regions. The service automatically routes to the appropriate regional endpoint based on your configuration. For a full list of regional model availability, visit: + +https://docs.oracle.com/en-us/iaas/Content/generative-ai/overview.htm#regions + +## Troubleshooting + +### Common Issues + +1. **Authentication Errors**: Verify your OCI credentials and IAM policies +2. **Model Not Found**: Ensure the model OCID is correct and the model is available in your region +3. **Permission Denied**: Check compartment permissions and Generative AI service access +4. **Region Unavailable**: Verify the specified region supports Generative AI services + +### Getting Help + +For additional support: +- [OCI Generative AI Documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm) +- [Llama Stack Issues](https://github.com/meta-llama/llama-stack/issues) diff --git a/docs/docs/providers/inference/remote_oci.mdx b/docs/docs/providers/inference/remote_oci.mdx new file mode 100644 index 000000000..33a201a55 --- /dev/null +++ b/docs/docs/providers/inference/remote_oci.mdx @@ -0,0 +1,41 @@ +--- +description: | + Oracle Cloud Infrastructure (OCI) Generative AI inference provider for accessing OCI's Generative AI Platform-as-a-Service models. + Provider documentation + https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm +sidebar_label: Remote - Oci +title: remote::oci +--- + +# remote::oci + +## Description + + +Oracle Cloud Infrastructure (OCI) Generative AI inference provider for accessing OCI's Generative AI Platform-as-a-Service models. 
+Provider documentation +https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm + + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | +| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider | +| `oci_auth_type` | `` | No | instance_principal | OCI authentication type (must be one of: instance_principal, config_file) | +| `oci_region` | `` | No | us-ashburn-1 | OCI region (e.g., us-ashburn-1) | +| `oci_compartment_id` | `` | No | | OCI compartment ID for the Generative AI service | +| `oci_config_file_path` | `` | No | ~/.oci/config | OCI config file path (required if oci_auth_type is config_file) | +| `oci_config_profile` | `` | No | DEFAULT | OCI config profile (required if oci_auth_type is config_file) | + +## Sample Configuration + +```yaml +oci_auth_type: ${env.OCI_AUTH_TYPE:=instance_principal} +oci_config_file_path: ${env.OCI_CONFIG_FILE_PATH:=~/.oci/config} +oci_config_profile: ${env.OCI_CLI_PROFILE:=DEFAULT} +oci_region: ${env.OCI_REGION:=us-ashburn-1} +oci_compartment_id: ${env.OCI_COMPARTMENT_OCID:=} +``` diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index 3bc965eb7..dea2e5bbe 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -13,7 +13,352 @@ info: migration reference only. servers: - url: http://any-hosted-llama-stack.com -paths: {} +paths: + /v1/models: + post: + responses: + '200': + description: A Model. + content: + application/json: + schema: + $ref: '#/components/schemas/Model' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Models + summary: Register model. + description: >- + Register model. + + Register a model. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RegisterModelRequest' + required: true + deprecated: true + /v1/models/{model_id}: + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Models + summary: Unregister model. + description: >- + Unregister model. + + Unregister a model. + parameters: + - name: model_id + in: path + description: >- + The identifier of the model to unregister. + required: true + schema: + type: string + deprecated: true + /v1/scoring-functions: + post: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - ScoringFunctions + summary: Register a scoring function. + description: Register a scoring function. 
+ parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RegisterScoringFunctionRequest' + required: true + deprecated: true + /v1/scoring-functions/{scoring_fn_id}: + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - ScoringFunctions + summary: Unregister a scoring function. + description: Unregister a scoring function. + parameters: + - name: scoring_fn_id + in: path + description: >- + The ID of the scoring function to unregister. + required: true + schema: + type: string + deprecated: true + /v1/shields: + post: + responses: + '200': + description: A Shield. + content: + application/json: + schema: + $ref: '#/components/schemas/Shield' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Shields + summary: Register a shield. + description: Register a shield. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RegisterShieldRequest' + required: true + deprecated: true + /v1/shields/{identifier}: + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Shields + summary: Unregister a shield. + description: Unregister a shield. + parameters: + - name: identifier + in: path + description: >- + The identifier of the shield to unregister. + required: true + schema: + type: string + deprecated: true + /v1/toolgroups: + post: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - ToolGroups + summary: Register a tool group. + description: Register a tool group. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RegisterToolGroupRequest' + required: true + deprecated: true + /v1/toolgroups/{toolgroup_id}: + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - ToolGroups + summary: Unregister a tool group. + description: Unregister a tool group. + parameters: + - name: toolgroup_id + in: path + description: The ID of the tool group to unregister. + required: true + schema: + type: string + deprecated: true + /v1beta/datasets: + post: + responses: + '200': + description: A Dataset. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/Dataset' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Datasets + summary: Register a new dataset. + description: Register a new dataset. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RegisterDatasetRequest' + required: true + deprecated: true + /v1beta/datasets/{dataset_id}: + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Datasets + summary: Unregister a dataset by its ID. + description: Unregister a dataset by its ID. + parameters: + - name: dataset_id + in: path + description: The ID of the dataset to unregister. + required: true + schema: + type: string + deprecated: true + /v1alpha/eval/benchmarks: + post: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Benchmarks + summary: Register a benchmark. + description: Register a benchmark. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RegisterBenchmarkRequest' + required: true + deprecated: true + /v1alpha/eval/benchmarks/{benchmark_id}: + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Benchmarks + summary: Unregister a benchmark. + description: Unregister a benchmark. + parameters: + - name: benchmark_id + in: path + description: The ID of the benchmark to unregister. + required: true + schema: + type: string + deprecated: true jsonSchemaDialect: >- https://json-schema.org/draft/2020-12/schema components: @@ -46,6 +391,730 @@ components: title: Error description: >- Error response from the API. Roughly follows RFC 7807. + ModelType: + type: string + enum: + - llm + - embedding + - rerank + title: ModelType + description: >- + Enumeration of supported model types in Llama Stack. + RegisterModelRequest: + type: object + properties: + model_id: + type: string + description: The identifier of the model to register. + provider_model_id: + type: string + description: >- + The identifier of the model in the provider. + provider_id: + type: string + description: The identifier of the provider. + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: Any additional metadata for this model. + model_type: + $ref: '#/components/schemas/ModelType' + description: The type of model to register. 
+ additionalProperties: false + required: + - model_id + title: RegisterModelRequest + Model: + type: object + properties: + identifier: + type: string + description: >- + Unique identifier for this resource in llama stack + provider_resource_id: + type: string + description: >- + Unique identifier for this resource in the provider + provider_id: + type: string + description: >- + ID of the provider that owns this resource + type: + type: string + enum: + - model + - shield + - vector_store + - dataset + - scoring_function + - benchmark + - tool + - tool_group + - prompt + const: model + default: model + description: >- + The resource type, always 'model' for model resources + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: Any additional metadata for this model + model_type: + $ref: '#/components/schemas/ModelType' + default: llm + description: >- + The type of model (LLM or embedding model) + additionalProperties: false + required: + - identifier + - provider_id + - type + - metadata + - model_type + title: Model + description: >- + A model resource representing an AI model registered in Llama Stack. + AggregationFunctionType: + type: string + enum: + - average + - weighted_average + - median + - categorical_count + - accuracy + title: AggregationFunctionType + description: >- + Types of aggregation functions for scoring results. + ArrayType: + type: object + properties: + type: + type: string + const: array + default: array + description: Discriminator type. Always "array" + additionalProperties: false + required: + - type + title: ArrayType + description: Parameter type for array values. + BasicScoringFnParams: + type: object + properties: + type: + $ref: '#/components/schemas/ScoringFnParamsType' + const: basic + default: basic + description: >- + The type of scoring function parameters, always basic + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + description: >- + Aggregation functions to apply to the scores of each row + additionalProperties: false + required: + - type + - aggregation_functions + title: BasicScoringFnParams + description: >- + Parameters for basic scoring function configuration. + BooleanType: + type: object + properties: + type: + type: string + const: boolean + default: boolean + description: Discriminator type. Always "boolean" + additionalProperties: false + required: + - type + title: BooleanType + description: Parameter type for boolean values. + ChatCompletionInputType: + type: object + properties: + type: + type: string + const: chat_completion_input + default: chat_completion_input + description: >- + Discriminator type. Always "chat_completion_input" + additionalProperties: false + required: + - type + title: ChatCompletionInputType + description: >- + Parameter type for chat completion input. + CompletionInputType: + type: object + properties: + type: + type: string + const: completion_input + default: completion_input + description: >- + Discriminator type. Always "completion_input" + additionalProperties: false + required: + - type + title: CompletionInputType + description: Parameter type for completion input. + JsonType: + type: object + properties: + type: + type: string + const: json + default: json + description: Discriminator type. Always "json" + additionalProperties: false + required: + - type + title: JsonType + description: Parameter type for JSON values. 
+ LLMAsJudgeScoringFnParams: + type: object + properties: + type: + $ref: '#/components/schemas/ScoringFnParamsType' + const: llm_as_judge + default: llm_as_judge + description: >- + The type of scoring function parameters, always llm_as_judge + judge_model: + type: string + description: >- + Identifier of the LLM model to use as a judge for scoring + prompt_template: + type: string + description: >- + (Optional) Custom prompt template for the judge model + judge_score_regexes: + type: array + items: + type: string + description: >- + Regexes to extract the answer from generated response + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + description: >- + Aggregation functions to apply to the scores of each row + additionalProperties: false + required: + - type + - judge_model + - judge_score_regexes + - aggregation_functions + title: LLMAsJudgeScoringFnParams + description: >- + Parameters for LLM-as-judge scoring function configuration. + NumberType: + type: object + properties: + type: + type: string + const: number + default: number + description: Discriminator type. Always "number" + additionalProperties: false + required: + - type + title: NumberType + description: Parameter type for numeric values. + ObjectType: + type: object + properties: + type: + type: string + const: object + default: object + description: Discriminator type. Always "object" + additionalProperties: false + required: + - type + title: ObjectType + description: Parameter type for object values. + ParamType: + oneOf: + - $ref: '#/components/schemas/StringType' + - $ref: '#/components/schemas/NumberType' + - $ref: '#/components/schemas/BooleanType' + - $ref: '#/components/schemas/ArrayType' + - $ref: '#/components/schemas/ObjectType' + - $ref: '#/components/schemas/JsonType' + - $ref: '#/components/schemas/UnionType' + - $ref: '#/components/schemas/ChatCompletionInputType' + - $ref: '#/components/schemas/CompletionInputType' + discriminator: + propertyName: type + mapping: + string: '#/components/schemas/StringType' + number: '#/components/schemas/NumberType' + boolean: '#/components/schemas/BooleanType' + array: '#/components/schemas/ArrayType' + object: '#/components/schemas/ObjectType' + json: '#/components/schemas/JsonType' + union: '#/components/schemas/UnionType' + chat_completion_input: '#/components/schemas/ChatCompletionInputType' + completion_input: '#/components/schemas/CompletionInputType' + RegexParserScoringFnParams: + type: object + properties: + type: + $ref: '#/components/schemas/ScoringFnParamsType' + const: regex_parser + default: regex_parser + description: >- + The type of scoring function parameters, always regex_parser + parsing_regexes: + type: array + items: + type: string + description: >- + Regex to extract the answer from generated response + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + description: >- + Aggregation functions to apply to the scores of each row + additionalProperties: false + required: + - type + - parsing_regexes + - aggregation_functions + title: RegexParserScoringFnParams + description: >- + Parameters for regex parser scoring function configuration. 
+ ScoringFnParams: + oneOf: + - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' + - $ref: '#/components/schemas/RegexParserScoringFnParams' + - $ref: '#/components/schemas/BasicScoringFnParams' + discriminator: + propertyName: type + mapping: + llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams' + regex_parser: '#/components/schemas/RegexParserScoringFnParams' + basic: '#/components/schemas/BasicScoringFnParams' + ScoringFnParamsType: + type: string + enum: + - llm_as_judge + - regex_parser + - basic + title: ScoringFnParamsType + description: >- + Types of scoring function parameter configurations. + StringType: + type: object + properties: + type: + type: string + const: string + default: string + description: Discriminator type. Always "string" + additionalProperties: false + required: + - type + title: StringType + description: Parameter type for string values. + UnionType: + type: object + properties: + type: + type: string + const: union + default: union + description: Discriminator type. Always "union" + additionalProperties: false + required: + - type + title: UnionType + description: Parameter type for union values. + RegisterScoringFunctionRequest: + type: object + properties: + scoring_fn_id: + type: string + description: >- + The ID of the scoring function to register. + description: + type: string + description: The description of the scoring function. + return_type: + $ref: '#/components/schemas/ParamType' + description: The return type of the scoring function. + provider_scoring_fn_id: + type: string + description: >- + The ID of the provider scoring function to use for the scoring function. + provider_id: + type: string + description: >- + The ID of the provider to use for the scoring function. + params: + $ref: '#/components/schemas/ScoringFnParams' + description: >- + The parameters for the scoring function for benchmark eval, these can + be overridden for app eval. + additionalProperties: false + required: + - scoring_fn_id + - description + - return_type + title: RegisterScoringFunctionRequest + RegisterShieldRequest: + type: object + properties: + shield_id: + type: string + description: >- + The identifier of the shield to register. + provider_shield_id: + type: string + description: >- + The identifier of the shield in the provider. + provider_id: + type: string + description: The identifier of the provider. + params: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: The parameters of the shield. + additionalProperties: false + required: + - shield_id + title: RegisterShieldRequest + Shield: + type: object + properties: + identifier: + type: string + provider_resource_id: + type: string + provider_id: + type: string + type: + type: string + enum: + - model + - shield + - vector_store + - dataset + - scoring_function + - benchmark + - tool + - tool_group + - prompt + const: shield + default: shield + description: The resource type, always shield + params: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + (Optional) Configuration parameters for the shield + additionalProperties: false + required: + - identifier + - provider_id + - type + title: Shield + description: >- + A safety shield resource that can be used to check content. 
+ URL: + type: object + properties: + uri: + type: string + description: The URL string pointing to the resource + additionalProperties: false + required: + - uri + title: URL + description: A URL reference to external content. + RegisterToolGroupRequest: + type: object + properties: + toolgroup_id: + type: string + description: The ID of the tool group to register. + provider_id: + type: string + description: >- + The ID of the provider to use for the tool group. + mcp_endpoint: + $ref: '#/components/schemas/URL' + description: >- + The MCP endpoint to use for the tool group. + args: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + A dictionary of arguments to pass to the tool group. + additionalProperties: false + required: + - toolgroup_id + - provider_id + title: RegisterToolGroupRequest + DataSource: + oneOf: + - $ref: '#/components/schemas/URIDataSource' + - $ref: '#/components/schemas/RowsDataSource' + discriminator: + propertyName: type + mapping: + uri: '#/components/schemas/URIDataSource' + rows: '#/components/schemas/RowsDataSource' + RowsDataSource: + type: object + properties: + type: + type: string + const: rows + default: rows + rows: + type: array + items: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + The dataset is stored in rows. E.g. - [ {"messages": [{"role": "user", + "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, + world!"}]} ] + additionalProperties: false + required: + - type + - rows + title: RowsDataSource + description: A dataset stored in rows. + URIDataSource: + type: object + properties: + type: + type: string + const: uri + default: uri + uri: + type: string + description: >- + The dataset can be obtained from a URI. E.g. - "https://mywebsite.com/mydata.jsonl" + - "lsfs://mydata.jsonl" - "data:csv;base64,{base64_content}" + additionalProperties: false + required: + - type + - uri + title: URIDataSource + description: >- + A dataset that can be obtained from a URI. + RegisterDatasetRequest: + type: object + properties: + purpose: + type: string + enum: + - post-training/messages + - eval/question-answer + - eval/messages-answer + description: >- + The purpose of the dataset. One of: - "post-training/messages": The dataset + contains a messages column with list of messages for post-training. { + "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant", + "content": "Hello, world!"}, ] } - "eval/question-answer": The dataset + contains a question column and an answer column for evaluation. { "question": + "What is the capital of France?", "answer": "Paris" } - "eval/messages-answer": + The dataset contains a messages column with list of messages and an answer + column for evaluation. { "messages": [ {"role": "user", "content": "Hello, + my name is John Doe."}, {"role": "assistant", "content": "Hello, John + Doe. How can I help you today?"}, {"role": "user", "content": "What's + my name?"}, ], "answer": "John Doe" } + source: + $ref: '#/components/schemas/DataSource' + description: >- + The data source of the dataset. Ensure that the data source schema is + compatible with the purpose of the dataset. 
Examples: - { "type": "uri", + "uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": + "lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}" + } - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train" + } - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content": + "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ] + } ] } + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + The metadata for the dataset. - E.g. {"description": "My dataset"}. + dataset_id: + type: string + description: >- + The ID of the dataset. If not provided, an ID will be generated. + additionalProperties: false + required: + - purpose + - source + title: RegisterDatasetRequest + Dataset: + type: object + properties: + identifier: + type: string + provider_resource_id: + type: string + provider_id: + type: string + type: + type: string + enum: + - model + - shield + - vector_store + - dataset + - scoring_function + - benchmark + - tool + - tool_group + - prompt + const: dataset + default: dataset + description: >- + Type of resource, always 'dataset' for datasets + purpose: + type: string + enum: + - post-training/messages + - eval/question-answer + - eval/messages-answer + description: >- + Purpose of the dataset indicating its intended use + source: + oneOf: + - $ref: '#/components/schemas/URIDataSource' + - $ref: '#/components/schemas/RowsDataSource' + discriminator: + propertyName: type + mapping: + uri: '#/components/schemas/URIDataSource' + rows: '#/components/schemas/RowsDataSource' + description: >- + Data source configuration for the dataset + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: Additional metadata for the dataset + additionalProperties: false + required: + - identifier + - provider_id + - type + - purpose + - source + - metadata + title: Dataset + description: >- + Dataset resource for storing and accessing training or evaluation data. + RegisterBenchmarkRequest: + type: object + properties: + benchmark_id: + type: string + description: The ID of the benchmark to register. + dataset_id: + type: string + description: >- + The ID of the dataset to use for the benchmark. + scoring_functions: + type: array + items: + type: string + description: >- + The scoring functions to use for the benchmark. + provider_benchmark_id: + type: string + description: >- + The ID of the provider benchmark to use for the benchmark. + provider_id: + type: string + description: >- + The ID of the provider to use for the benchmark. + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: The metadata to use for the benchmark. 
+ additionalProperties: false + required: + - benchmark_id + - dataset_id + - scoring_functions + title: RegisterBenchmarkRequest responses: BadRequest400: description: The request was invalid or malformed @@ -93,4 +1162,25 @@ components: detail: An unexpected error occurred security: - Default: [] -tags: [] +tags: + - name: Benchmarks + description: '' + - name: Datasets + description: '' + - name: Models + description: '' + - name: ScoringFunctions + description: '' + - name: Shields + description: '' + - name: ToolGroups + description: '' +x-tagGroups: + - name: Operations + tags: + - Benchmarks + - Datasets + - Models + - ScoringFunctions + - Shields + - ToolGroups diff --git a/docs/static/experimental-llama-stack-spec.yaml b/docs/static/experimental-llama-stack-spec.yaml index 68e2f59be..6f379d17c 100644 --- a/docs/static/experimental-llama-stack-spec.yaml +++ b/docs/static/experimental-llama-stack-spec.yaml @@ -162,7 +162,7 @@ paths: schema: $ref: '#/components/schemas/RegisterDatasetRequest' required: true - deprecated: false + deprecated: true /v1beta/datasets/{dataset_id}: get: responses: @@ -219,7 +219,7 @@ paths: required: true schema: type: string - deprecated: false + deprecated: true /v1alpha/eval/benchmarks: get: responses: @@ -270,7 +270,7 @@ paths: schema: $ref: '#/components/schemas/RegisterBenchmarkRequest' required: true - deprecated: false + deprecated: true /v1alpha/eval/benchmarks/{benchmark_id}: get: responses: @@ -327,7 +327,7 @@ paths: required: true schema: type: string - deprecated: false + deprecated: true /v1alpha/eval/benchmarks/{benchmark_id}/evaluations: post: responses: @@ -936,68 +936,6 @@ components: - data title: ListDatasetsResponse description: Response from listing datasets. - DataSource: - oneOf: - - $ref: '#/components/schemas/URIDataSource' - - $ref: '#/components/schemas/RowsDataSource' - discriminator: - propertyName: type - mapping: - uri: '#/components/schemas/URIDataSource' - rows: '#/components/schemas/RowsDataSource' - RegisterDatasetRequest: - type: object - properties: - purpose: - type: string - enum: - - post-training/messages - - eval/question-answer - - eval/messages-answer - description: >- - The purpose of the dataset. One of: - "post-training/messages": The dataset - contains a messages column with list of messages for post-training. { - "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant", - "content": "Hello, world!"}, ] } - "eval/question-answer": The dataset - contains a question column and an answer column for evaluation. { "question": - "What is the capital of France?", "answer": "Paris" } - "eval/messages-answer": - The dataset contains a messages column with list of messages and an answer - column for evaluation. { "messages": [ {"role": "user", "content": "Hello, - my name is John Doe."}, {"role": "assistant", "content": "Hello, John - Doe. How can I help you today?"}, {"role": "user", "content": "What's - my name?"}, ], "answer": "John Doe" } - source: - $ref: '#/components/schemas/DataSource' - description: >- - The data source of the dataset. Ensure that the data source schema is - compatible with the purpose of the dataset. 
Examples: - { "type": "uri", - "uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": - "lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}" - } - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train" - } - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content": - "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ] - } ] } - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - The metadata for the dataset. - E.g. {"description": "My dataset"}. - dataset_id: - type: string - description: >- - The ID of the dataset. If not provided, an ID will be generated. - additionalProperties: false - required: - - purpose - - source - title: RegisterDatasetRequest Benchmark: type: object properties: @@ -1065,47 +1003,6 @@ components: required: - data title: ListBenchmarksResponse - RegisterBenchmarkRequest: - type: object - properties: - benchmark_id: - type: string - description: The ID of the benchmark to register. - dataset_id: - type: string - description: >- - The ID of the dataset to use for the benchmark. - scoring_functions: - type: array - items: - type: string - description: >- - The scoring functions to use for the benchmark. - provider_benchmark_id: - type: string - description: >- - The ID of the provider benchmark to use for the benchmark. - provider_id: - type: string - description: >- - The ID of the provider to use for the benchmark. - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: The metadata to use for the benchmark. - additionalProperties: false - required: - - benchmark_id - - dataset_id - - scoring_functions - title: RegisterBenchmarkRequest AggregationFunctionType: type: string enum: @@ -2254,6 +2151,109 @@ components: - hyperparam_search_config - logger_config title: SupervisedFineTuneRequest + DataSource: + oneOf: + - $ref: '#/components/schemas/URIDataSource' + - $ref: '#/components/schemas/RowsDataSource' + discriminator: + propertyName: type + mapping: + uri: '#/components/schemas/URIDataSource' + rows: '#/components/schemas/RowsDataSource' + RegisterDatasetRequest: + type: object + properties: + purpose: + type: string + enum: + - post-training/messages + - eval/question-answer + - eval/messages-answer + description: >- + The purpose of the dataset. One of: - "post-training/messages": The dataset + contains a messages column with list of messages for post-training. { + "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant", + "content": "Hello, world!"}, ] } - "eval/question-answer": The dataset + contains a question column and an answer column for evaluation. { "question": + "What is the capital of France?", "answer": "Paris" } - "eval/messages-answer": + The dataset contains a messages column with list of messages and an answer + column for evaluation. { "messages": [ {"role": "user", "content": "Hello, + my name is John Doe."}, {"role": "assistant", "content": "Hello, John + Doe. How can I help you today?"}, {"role": "user", "content": "What's + my name?"}, ], "answer": "John Doe" } + source: + $ref: '#/components/schemas/DataSource' + description: >- + The data source of the dataset. Ensure that the data source schema is + compatible with the purpose of the dataset. 
Examples: - { "type": "uri", + "uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": + "lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}" + } - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train" + } - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content": + "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ] + } ] } + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + The metadata for the dataset. - E.g. {"description": "My dataset"}. + dataset_id: + type: string + description: >- + The ID of the dataset. If not provided, an ID will be generated. + additionalProperties: false + required: + - purpose + - source + title: RegisterDatasetRequest + RegisterBenchmarkRequest: + type: object + properties: + benchmark_id: + type: string + description: The ID of the benchmark to register. + dataset_id: + type: string + description: >- + The ID of the dataset to use for the benchmark. + scoring_functions: + type: array + items: + type: string + description: >- + The scoring functions to use for the benchmark. + provider_benchmark_id: + type: string + description: >- + The ID of the provider benchmark to use for the benchmark. + provider_id: + type: string + description: >- + The ID of the provider to use for the benchmark. + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: The metadata to use for the benchmark. + additionalProperties: false + required: + - benchmark_id + - dataset_id + - scoring_functions + title: RegisterBenchmarkRequest responses: BadRequest400: description: The request was invalid or malformed diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index ea7fd6eec..ce8708b68 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -960,7 +960,7 @@ paths: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, - returns only non-deprecated v1 routes. + returns all non-deprecated routes. required: false schema: type: string @@ -995,39 +995,6 @@ paths: description: List models using the OpenAI API. parameters: [] deprecated: false - post: - responses: - '200': - description: A Model. - content: - application/json: - schema: - $ref: '#/components/schemas/Model' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Models - summary: Register model. - description: >- - Register model. - - Register a model. 
- parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterModelRequest' - required: true - deprecated: false /v1/models/{model_id}: get: responses: @@ -1062,36 +1029,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Models - summary: Unregister model. - description: >- - Unregister model. - - Unregister a model. - parameters: - - name: model_id - in: path - description: >- - The identifier of the model to unregister. - required: true - schema: - type: string - deprecated: false /v1/moderations: post: responses: @@ -1722,32 +1659,6 @@ paths: description: List all scoring functions. parameters: [] deprecated: false - post: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ScoringFunctions - summary: Register a scoring function. - description: Register a scoring function. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterScoringFunctionRequest' - required: true - deprecated: false /v1/scoring-functions/{scoring_fn_id}: get: responses: @@ -1779,33 +1690,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ScoringFunctions - summary: Unregister a scoring function. - description: Unregister a scoring function. - parameters: - - name: scoring_fn_id - in: path - description: >- - The ID of the scoring function to unregister. - required: true - schema: - type: string - deprecated: false /v1/scoring/score: post: responses: @@ -1894,36 +1778,6 @@ paths: description: List all shields. parameters: [] deprecated: false - post: - responses: - '200': - description: A Shield. - content: - application/json: - schema: - $ref: '#/components/schemas/Shield' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Shields - summary: Register a shield. - description: Register a shield. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterShieldRequest' - required: true - deprecated: false /v1/shields/{identifier}: get: responses: @@ -1955,33 +1809,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Shields - summary: Unregister a shield. - description: Unregister a shield. 
- parameters: - - name: identifier - in: path - description: >- - The identifier of the shield to unregister. - required: true - schema: - type: string - deprecated: false /v1/tool-runtime/invoke: post: responses: @@ -2077,32 +1904,6 @@ paths: description: List tool groups with optional provider. parameters: [] deprecated: false - post: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ToolGroups - summary: Register a tool group. - description: Register a tool group. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterToolGroupRequest' - required: true - deprecated: false /v1/toolgroups/{toolgroup_id}: get: responses: @@ -2134,32 +1935,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ToolGroups - summary: Unregister a tool group. - description: Unregister a tool group. - parameters: - - name: toolgroup_id - in: path - description: The ID of the tool group to unregister. - required: true - schema: - type: string - deprecated: false /v1/tools: get: responses: @@ -2913,11 +2688,11 @@ paths: responses: '200': description: >- - A list of InterleavedContent representing the file contents. + A VectorStoreFileContentResponse representing the file contents. content: application/json: schema: - $ref: '#/components/schemas/VectorStoreFileContentsResponse' + $ref: '#/components/schemas/VectorStoreFileContentResponse' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -5564,46 +5339,6 @@ components: required: - data title: OpenAIListModelsResponse - ModelType: - type: string - enum: - - llm - - embedding - - rerank - title: ModelType - description: >- - Enumeration of supported model types in Llama Stack. - RegisterModelRequest: - type: object - properties: - model_id: - type: string - description: The identifier of the model to register. - provider_model_id: - type: string - description: >- - The identifier of the model in the provider. - provider_id: - type: string - description: The identifier of the provider. - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: Any additional metadata for this model. - model_type: - $ref: '#/components/schemas/ModelType' - description: The type of model to register. - additionalProperties: false - required: - - model_id - title: RegisterModelRequest Model: type: object properties: @@ -5661,6 +5396,15 @@ components: title: Model description: >- A model resource representing an AI model registered in Llama Stack. + ModelType: + type: string + enum: + - llm + - embedding + - rerank + title: ModelType + description: >- + Enumeration of supported model types in Llama Stack. 
RunModerationRequest: type: object properties: @@ -6166,6 +5910,11 @@ components: type: string description: >- (Optional) System message inserted into the model's context + max_tool_calls: + type: integer + description: >- + (Optional) Max number of total calls to built-in tools that can be processed + in a response input: type: array items: @@ -6524,6 +6273,11 @@ components: (Optional) Additional fields to include in the response. max_infer_iters: type: integer + max_tool_calls: + type: integer + description: >- + (Optional) Max number of total calls to built-in tools that can be processed + in a response. additionalProperties: false required: - input @@ -6605,6 +6359,11 @@ components: type: string description: >- (Optional) System message inserted into the model's context + max_tool_calls: + type: integer + description: >- + (Optional) Max number of total calls to built-in tools that can be processed + in a response additionalProperties: false required: - created_at @@ -8399,61 +8158,6 @@ components: required: - data title: ListScoringFunctionsResponse - ParamType: - oneOf: - - $ref: '#/components/schemas/StringType' - - $ref: '#/components/schemas/NumberType' - - $ref: '#/components/schemas/BooleanType' - - $ref: '#/components/schemas/ArrayType' - - $ref: '#/components/schemas/ObjectType' - - $ref: '#/components/schemas/JsonType' - - $ref: '#/components/schemas/UnionType' - - $ref: '#/components/schemas/ChatCompletionInputType' - - $ref: '#/components/schemas/CompletionInputType' - discriminator: - propertyName: type - mapping: - string: '#/components/schemas/StringType' - number: '#/components/schemas/NumberType' - boolean: '#/components/schemas/BooleanType' - array: '#/components/schemas/ArrayType' - object: '#/components/schemas/ObjectType' - json: '#/components/schemas/JsonType' - union: '#/components/schemas/UnionType' - chat_completion_input: '#/components/schemas/ChatCompletionInputType' - completion_input: '#/components/schemas/CompletionInputType' - RegisterScoringFunctionRequest: - type: object - properties: - scoring_fn_id: - type: string - description: >- - The ID of the scoring function to register. - description: - type: string - description: The description of the scoring function. - return_type: - $ref: '#/components/schemas/ParamType' - description: The return type of the scoring function. - provider_scoring_fn_id: - type: string - description: >- - The ID of the provider scoring function to use for the scoring function. - provider_id: - type: string - description: >- - The ID of the provider to use for the scoring function. - params: - $ref: '#/components/schemas/ScoringFnParams' - description: >- - The parameters for the scoring function for benchmark eval, these can - be overridden for app eval. - additionalProperties: false - required: - - scoring_fn_id - - description - - return_type - title: RegisterScoringFunctionRequest ScoreRequest: type: object properties: @@ -8629,35 +8333,6 @@ components: required: - data title: ListShieldsResponse - RegisterShieldRequest: - type: object - properties: - shield_id: - type: string - description: >- - The identifier of the shield to register. - provider_shield_id: - type: string - description: >- - The identifier of the shield in the provider. - provider_id: - type: string - description: The identifier of the provider. - params: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: The parameters of the shield. 
- additionalProperties: false - required: - - shield_id - title: RegisterShieldRequest InvokeToolRequest: type: object properties: @@ -8918,37 +8593,6 @@ components: title: ListToolGroupsResponse description: >- Response containing a list of tool groups. - RegisterToolGroupRequest: - type: object - properties: - toolgroup_id: - type: string - description: The ID of the tool group to register. - provider_id: - type: string - description: >- - The ID of the provider to use for the tool group. - mcp_endpoint: - $ref: '#/components/schemas/URL' - description: >- - The MCP endpoint to use for the tool group. - args: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - A dictionary of arguments to pass to the tool group. - additionalProperties: false - required: - - toolgroup_id - - provider_id - title: RegisterToolGroupRequest Chunk: type: object properties: @@ -9749,41 +9393,35 @@ components: title: VectorStoreContent description: >- Content item from a vector store file or search result. - VectorStoreFileContentsResponse: + VectorStoreFileContentResponse: type: object properties: - file_id: + object: type: string - description: Unique identifier for the file - filename: - type: string - description: Name of the file - attributes: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object + const: vector_store.file_content.page + default: vector_store.file_content.page description: >- - Key-value attributes associated with the file - content: + The object type, which is always `vector_store.file_content.page` + data: type: array items: $ref: '#/components/schemas/VectorStoreContent' - description: List of content items from the file + description: Parsed content of the file + has_more: + type: boolean + description: >- + Indicates if there are more content pages to fetch + next_page: + type: string + description: The token for the next page, if any additionalProperties: false required: - - file_id - - filename - - attributes - - content - title: VectorStoreFileContentsResponse + - object + - data + - has_more + title: VectorStoreFileContentResponse description: >- - Response from retrieving the contents of a vector store file. + Represents the parsed content of a vector store file. OpenaiSearchVectorStoreRequest: type: object properties: diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index d8159be62..9f3ef15b5 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -963,7 +963,7 @@ paths: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, - returns only non-deprecated v1 routes. + returns all non-deprecated routes. required: false schema: type: string @@ -998,39 +998,6 @@ paths: description: List models using the OpenAI API. parameters: [] deprecated: false - post: - responses: - '200': - description: A Model. 
- content: - application/json: - schema: - $ref: '#/components/schemas/Model' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Models - summary: Register model. - description: >- - Register model. - - Register a model. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterModelRequest' - required: true - deprecated: false /v1/models/{model_id}: get: responses: @@ -1065,36 +1032,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Models - summary: Unregister model. - description: >- - Unregister model. - - Unregister a model. - parameters: - - name: model_id - in: path - description: >- - The identifier of the model to unregister. - required: true - schema: - type: string - deprecated: false /v1/moderations: post: responses: @@ -1725,32 +1662,6 @@ paths: description: List all scoring functions. parameters: [] deprecated: false - post: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ScoringFunctions - summary: Register a scoring function. - description: Register a scoring function. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterScoringFunctionRequest' - required: true - deprecated: false /v1/scoring-functions/{scoring_fn_id}: get: responses: @@ -1782,33 +1693,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ScoringFunctions - summary: Unregister a scoring function. - description: Unregister a scoring function. - parameters: - - name: scoring_fn_id - in: path - description: >- - The ID of the scoring function to unregister. - required: true - schema: - type: string - deprecated: false /v1/scoring/score: post: responses: @@ -1897,36 +1781,6 @@ paths: description: List all shields. parameters: [] deprecated: false - post: - responses: - '200': - description: A Shield. - content: - application/json: - schema: - $ref: '#/components/schemas/Shield' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Shields - summary: Register a shield. - description: Register a shield. 
- parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterShieldRequest' - required: true - deprecated: false /v1/shields/{identifier}: get: responses: @@ -1958,33 +1812,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Shields - summary: Unregister a shield. - description: Unregister a shield. - parameters: - - name: identifier - in: path - description: >- - The identifier of the shield to unregister. - required: true - schema: - type: string - deprecated: false /v1/tool-runtime/invoke: post: responses: @@ -2080,32 +1907,6 @@ paths: description: List tool groups with optional provider. parameters: [] deprecated: false - post: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ToolGroups - summary: Register a tool group. - description: Register a tool group. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterToolGroupRequest' - required: true - deprecated: false /v1/toolgroups/{toolgroup_id}: get: responses: @@ -2137,32 +1938,6 @@ paths: schema: type: string deprecated: false - delete: - responses: - '200': - description: OK - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - ToolGroups - summary: Unregister a tool group. - description: Unregister a tool group. - parameters: - - name: toolgroup_id - in: path - description: The ID of the tool group to unregister. - required: true - schema: - type: string - deprecated: false /v1/tools: get: responses: @@ -2916,11 +2691,11 @@ paths: responses: '200': description: >- - A list of InterleavedContent representing the file contents. + A VectorStoreFileContentResponse representing the file contents. 
content: application/json: schema: - $ref: '#/components/schemas/VectorStoreFileContentsResponse' + $ref: '#/components/schemas/VectorStoreFileContentResponse' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -3171,7 +2946,7 @@ paths: schema: $ref: '#/components/schemas/RegisterDatasetRequest' required: true - deprecated: false + deprecated: true /v1beta/datasets/{dataset_id}: get: responses: @@ -3228,7 +3003,7 @@ paths: required: true schema: type: string - deprecated: false + deprecated: true /v1alpha/eval/benchmarks: get: responses: @@ -3279,7 +3054,7 @@ paths: schema: $ref: '#/components/schemas/RegisterBenchmarkRequest' required: true - deprecated: false + deprecated: true /v1alpha/eval/benchmarks/{benchmark_id}: get: responses: @@ -3336,7 +3111,7 @@ paths: required: true schema: type: string - deprecated: false + deprecated: true /v1alpha/eval/benchmarks/{benchmark_id}/evaluations: post: responses: @@ -6280,46 +6055,6 @@ components: required: - data title: OpenAIListModelsResponse - ModelType: - type: string - enum: - - llm - - embedding - - rerank - title: ModelType - description: >- - Enumeration of supported model types in Llama Stack. - RegisterModelRequest: - type: object - properties: - model_id: - type: string - description: The identifier of the model to register. - provider_model_id: - type: string - description: >- - The identifier of the model in the provider. - provider_id: - type: string - description: The identifier of the provider. - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: Any additional metadata for this model. - model_type: - $ref: '#/components/schemas/ModelType' - description: The type of model to register. - additionalProperties: false - required: - - model_id - title: RegisterModelRequest Model: type: object properties: @@ -6377,6 +6112,15 @@ components: title: Model description: >- A model resource representing an AI model registered in Llama Stack. + ModelType: + type: string + enum: + - llm + - embedding + - rerank + title: ModelType + description: >- + Enumeration of supported model types in Llama Stack. RunModerationRequest: type: object properties: @@ -6882,6 +6626,11 @@ components: type: string description: >- (Optional) System message inserted into the model's context + max_tool_calls: + type: integer + description: >- + (Optional) Max number of total calls to built-in tools that can be processed + in a response input: type: array items: @@ -7240,6 +6989,11 @@ components: (Optional) Additional fields to include in the response. max_infer_iters: type: integer + max_tool_calls: + type: integer + description: >- + (Optional) Max number of total calls to built-in tools that can be processed + in a response. 
additionalProperties: false required: - input @@ -7321,6 +7075,11 @@ components: type: string description: >- (Optional) System message inserted into the model's context + max_tool_calls: + type: integer + description: >- + (Optional) Max number of total calls to built-in tools that can be processed + in a response additionalProperties: false required: - created_at @@ -9115,61 +8874,6 @@ components: required: - data title: ListScoringFunctionsResponse - ParamType: - oneOf: - - $ref: '#/components/schemas/StringType' - - $ref: '#/components/schemas/NumberType' - - $ref: '#/components/schemas/BooleanType' - - $ref: '#/components/schemas/ArrayType' - - $ref: '#/components/schemas/ObjectType' - - $ref: '#/components/schemas/JsonType' - - $ref: '#/components/schemas/UnionType' - - $ref: '#/components/schemas/ChatCompletionInputType' - - $ref: '#/components/schemas/CompletionInputType' - discriminator: - propertyName: type - mapping: - string: '#/components/schemas/StringType' - number: '#/components/schemas/NumberType' - boolean: '#/components/schemas/BooleanType' - array: '#/components/schemas/ArrayType' - object: '#/components/schemas/ObjectType' - json: '#/components/schemas/JsonType' - union: '#/components/schemas/UnionType' - chat_completion_input: '#/components/schemas/ChatCompletionInputType' - completion_input: '#/components/schemas/CompletionInputType' - RegisterScoringFunctionRequest: - type: object - properties: - scoring_fn_id: - type: string - description: >- - The ID of the scoring function to register. - description: - type: string - description: The description of the scoring function. - return_type: - $ref: '#/components/schemas/ParamType' - description: The return type of the scoring function. - provider_scoring_fn_id: - type: string - description: >- - The ID of the provider scoring function to use for the scoring function. - provider_id: - type: string - description: >- - The ID of the provider to use for the scoring function. - params: - $ref: '#/components/schemas/ScoringFnParams' - description: >- - The parameters for the scoring function for benchmark eval, these can - be overridden for app eval. - additionalProperties: false - required: - - scoring_fn_id - - description - - return_type - title: RegisterScoringFunctionRequest ScoreRequest: type: object properties: @@ -9345,35 +9049,6 @@ components: required: - data title: ListShieldsResponse - RegisterShieldRequest: - type: object - properties: - shield_id: - type: string - description: >- - The identifier of the shield to register. - provider_shield_id: - type: string - description: >- - The identifier of the shield in the provider. - provider_id: - type: string - description: The identifier of the provider. - params: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: The parameters of the shield. - additionalProperties: false - required: - - shield_id - title: RegisterShieldRequest InvokeToolRequest: type: object properties: @@ -9634,37 +9309,6 @@ components: title: ListToolGroupsResponse description: >- Response containing a list of tool groups. - RegisterToolGroupRequest: - type: object - properties: - toolgroup_id: - type: string - description: The ID of the tool group to register. - provider_id: - type: string - description: >- - The ID of the provider to use for the tool group. - mcp_endpoint: - $ref: '#/components/schemas/URL' - description: >- - The MCP endpoint to use for the tool group. 
- args: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - A dictionary of arguments to pass to the tool group. - additionalProperties: false - required: - - toolgroup_id - - provider_id - title: RegisterToolGroupRequest Chunk: type: object properties: @@ -10465,41 +10109,35 @@ components: title: VectorStoreContent description: >- Content item from a vector store file or search result. - VectorStoreFileContentsResponse: + VectorStoreFileContentResponse: type: object properties: - file_id: + object: type: string - description: Unique identifier for the file - filename: - type: string - description: Name of the file - attributes: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object + const: vector_store.file_content.page + default: vector_store.file_content.page description: >- - Key-value attributes associated with the file - content: + The object type, which is always `vector_store.file_content.page` + data: type: array items: $ref: '#/components/schemas/VectorStoreContent' - description: List of content items from the file + description: Parsed content of the file + has_more: + type: boolean + description: >- + Indicates if there are more content pages to fetch + next_page: + type: string + description: The token for the next page, if any additionalProperties: false required: - - file_id - - filename - - attributes - - content - title: VectorStoreFileContentsResponse + - object + - data + - has_more + title: VectorStoreFileContentResponse description: >- - Response from retrieving the contents of a vector store file. + Represents the parsed content of a vector store file. OpenaiSearchVectorStoreRequest: type: object properties: @@ -10816,68 +10454,6 @@ components: - data title: ListDatasetsResponse description: Response from listing datasets. - DataSource: - oneOf: - - $ref: '#/components/schemas/URIDataSource' - - $ref: '#/components/schemas/RowsDataSource' - discriminator: - propertyName: type - mapping: - uri: '#/components/schemas/URIDataSource' - rows: '#/components/schemas/RowsDataSource' - RegisterDatasetRequest: - type: object - properties: - purpose: - type: string - enum: - - post-training/messages - - eval/question-answer - - eval/messages-answer - description: >- - The purpose of the dataset. One of: - "post-training/messages": The dataset - contains a messages column with list of messages for post-training. { - "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant", - "content": "Hello, world!"}, ] } - "eval/question-answer": The dataset - contains a question column and an answer column for evaluation. { "question": - "What is the capital of France?", "answer": "Paris" } - "eval/messages-answer": - The dataset contains a messages column with list of messages and an answer - column for evaluation. { "messages": [ {"role": "user", "content": "Hello, - my name is John Doe."}, {"role": "assistant", "content": "Hello, John - Doe. How can I help you today?"}, {"role": "user", "content": "What's - my name?"}, ], "answer": "John Doe" } - source: - $ref: '#/components/schemas/DataSource' - description: >- - The data source of the dataset. Ensure that the data source schema is - compatible with the purpose of the dataset. 
Examples: - { "type": "uri", - "uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": - "lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}" - } - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train" - } - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content": - "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ] - } ] } - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - The metadata for the dataset. - E.g. {"description": "My dataset"}. - dataset_id: - type: string - description: >- - The ID of the dataset. If not provided, an ID will be generated. - additionalProperties: false - required: - - purpose - - source - title: RegisterDatasetRequest Benchmark: type: object properties: @@ -10945,47 +10521,6 @@ components: required: - data title: ListBenchmarksResponse - RegisterBenchmarkRequest: - type: object - properties: - benchmark_id: - type: string - description: The ID of the benchmark to register. - dataset_id: - type: string - description: >- - The ID of the dataset to use for the benchmark. - scoring_functions: - type: array - items: - type: string - description: >- - The scoring functions to use for the benchmark. - provider_benchmark_id: - type: string - description: >- - The ID of the provider benchmark to use for the benchmark. - provider_id: - type: string - description: >- - The ID of the provider to use for the benchmark. - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: The metadata to use for the benchmark. - additionalProperties: false - required: - - benchmark_id - - dataset_id - - scoring_functions - title: RegisterBenchmarkRequest BenchmarkConfig: type: object properties: @@ -11847,6 +11382,109 @@ components: - hyperparam_search_config - logger_config title: SupervisedFineTuneRequest + DataSource: + oneOf: + - $ref: '#/components/schemas/URIDataSource' + - $ref: '#/components/schemas/RowsDataSource' + discriminator: + propertyName: type + mapping: + uri: '#/components/schemas/URIDataSource' + rows: '#/components/schemas/RowsDataSource' + RegisterDatasetRequest: + type: object + properties: + purpose: + type: string + enum: + - post-training/messages + - eval/question-answer + - eval/messages-answer + description: >- + The purpose of the dataset. One of: - "post-training/messages": The dataset + contains a messages column with list of messages for post-training. { + "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant", + "content": "Hello, world!"}, ] } - "eval/question-answer": The dataset + contains a question column and an answer column for evaluation. { "question": + "What is the capital of France?", "answer": "Paris" } - "eval/messages-answer": + The dataset contains a messages column with list of messages and an answer + column for evaluation. { "messages": [ {"role": "user", "content": "Hello, + my name is John Doe."}, {"role": "assistant", "content": "Hello, John + Doe. How can I help you today?"}, {"role": "user", "content": "What's + my name?"}, ], "answer": "John Doe" } + source: + $ref: '#/components/schemas/DataSource' + description: >- + The data source of the dataset. Ensure that the data source schema is + compatible with the purpose of the dataset. 
Examples: - { "type": "uri", + "uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": + "lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}" + } - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train" + } - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content": + "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ] + } ] } + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + The metadata for the dataset. - E.g. {"description": "My dataset"}. + dataset_id: + type: string + description: >- + The ID of the dataset. If not provided, an ID will be generated. + additionalProperties: false + required: + - purpose + - source + title: RegisterDatasetRequest + RegisterBenchmarkRequest: + type: object + properties: + benchmark_id: + type: string + description: The ID of the benchmark to register. + dataset_id: + type: string + description: >- + The ID of the dataset to use for the benchmark. + scoring_functions: + type: array + items: + type: string + description: >- + The scoring functions to use for the benchmark. + provider_benchmark_id: + type: string + description: >- + The ID of the provider benchmark to use for the benchmark. + provider_id: + type: string + description: >- + The ID of the provider to use for the benchmark. + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: The metadata to use for the benchmark. + additionalProperties: false + required: + - benchmark_id + - dataset_id + - scoring_functions + title: RegisterBenchmarkRequest responses: BadRequest400: description: The request was invalid or malformed diff --git a/pyproject.toml b/pyproject.toml index 4ec83249c..653c6d613 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -298,6 +298,7 @@ exclude = [ "^src/llama_stack/providers/remote/agents/sample/", "^src/llama_stack/providers/remote/datasetio/huggingface/", "^src/llama_stack/providers/remote/datasetio/nvidia/", + "^src/llama_stack/providers/remote/inference/oci/", "^src/llama_stack/providers/remote/inference/bedrock/", "^src/llama_stack/providers/remote/inference/nvidia/", "^src/llama_stack/providers/remote/inference/passthrough/", diff --git a/src/llama_stack/apis/agents/agents.py b/src/llama_stack/apis/agents/agents.py index cadef2edc..09687ef33 100644 --- a/src/llama_stack/apis/agents/agents.py +++ b/src/llama_stack/apis/agents/agents.py @@ -87,6 +87,7 @@ class Agents(Protocol): "List of guardrails to apply during response generation. Guardrails provide safety and content moderation." ), ] = None, + max_tool_calls: int | None = None, ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]: """Create a model response. @@ -97,6 +98,7 @@ class Agents(Protocol): :param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation. :param include: (Optional) Additional fields to include in the response. :param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications. + :param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response. :returns: An OpenAIResponseObject. """ ... 
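The agents.py change above threads the new max_tool_calls argument through the create-response protocol method. As a rough orientation only, the sketch below shows how a caller might set the cap once the generated Python client picks up this parameter; the client import, base URL, model id, and tool spec are illustrative assumptions, not part of this diff.

    # Hypothetical usage sketch, assuming the regenerated llama-stack-client
    # exposes the Responses API as client.responses.create() with the new field.
    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server

    response = client.responses.create(
        model="llama3.2:3b",                      # placeholder model id
        input="Look up today's weather and summarize it.",
        tools=[{"type": "web_search"}],           # a built-in tool the cap applies to
        max_tool_calls=2,                         # new: cap total built-in tool invocations
    )
    print(response.id)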
diff --git a/src/llama_stack/apis/agents/openai_responses.py b/src/llama_stack/apis/agents/openai_responses.py index a38d1cba6..16657ab32 100644 --- a/src/llama_stack/apis/agents/openai_responses.py +++ b/src/llama_stack/apis/agents/openai_responses.py @@ -594,6 +594,7 @@ class OpenAIResponseObject(BaseModel): :param truncation: (Optional) Truncation strategy applied to the response :param usage: (Optional) Token usage information for the response :param instructions: (Optional) System message inserted into the model's context + :param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response """ created_at: int @@ -615,6 +616,7 @@ class OpenAIResponseObject(BaseModel): truncation: str | None = None usage: OpenAIResponseUsage | None = None instructions: str | None = None + max_tool_calls: int | None = None @json_schema_type diff --git a/src/llama_stack/apis/benchmarks/benchmarks.py b/src/llama_stack/apis/benchmarks/benchmarks.py index 933205489..9a67269c3 100644 --- a/src/llama_stack/apis/benchmarks/benchmarks.py +++ b/src/llama_stack/apis/benchmarks/benchmarks.py @@ -74,7 +74,7 @@ class Benchmarks(Protocol): """ ... - @webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA) + @webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA, deprecated=True) async def register_benchmark( self, benchmark_id: str, @@ -95,7 +95,7 @@ class Benchmarks(Protocol): """ ... - @webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA) + @webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA, deprecated=True) async def unregister_benchmark(self, benchmark_id: str) -> None: """Unregister a benchmark. diff --git a/src/llama_stack/apis/datasets/datasets.py b/src/llama_stack/apis/datasets/datasets.py index ed4ecec22..9bedc6209 100644 --- a/src/llama_stack/apis/datasets/datasets.py +++ b/src/llama_stack/apis/datasets/datasets.py @@ -146,7 +146,7 @@ class ListDatasetsResponse(BaseModel): class Datasets(Protocol): - @webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA) + @webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA, deprecated=True) async def register_dataset( self, purpose: DatasetPurpose, @@ -235,7 +235,7 @@ class Datasets(Protocol): """ ... - @webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA) + @webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA, deprecated=True) async def unregister_dataset( self, dataset_id: str, diff --git a/src/llama_stack/apis/inference/event_logger.py b/src/llama_stack/apis/inference/event_logger.py deleted file mode 100644 index d97ece6d4..000000000 --- a/src/llama_stack/apis/inference/event_logger.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from termcolor import cprint - -from llama_stack.apis.inference import ( - ChatCompletionResponseEventType, - ChatCompletionResponseStreamChunk, -) - - -class LogEvent: - def __init__( - self, - content: str = "", - end: str = "\n", - color="white", - ): - self.content = content - self.color = color - self.end = "\n" if end is None else end - - def print(self, flush=True): - cprint(f"{self.content}", color=self.color, end=self.end, flush=flush) - - -class EventLogger: - async def log(self, event_generator): - async for chunk in event_generator: - if isinstance(chunk, ChatCompletionResponseStreamChunk): - event = chunk.event - if event.event_type == ChatCompletionResponseEventType.start: - yield LogEvent("Assistant> ", color="cyan", end="") - elif event.event_type == ChatCompletionResponseEventType.progress: - yield LogEvent(event.delta, color="yellow", end="") - elif event.event_type == ChatCompletionResponseEventType.complete: - yield LogEvent("") - else: - yield LogEvent("Assistant> ", color="cyan", end="") - yield LogEvent(chunk.completion_message.content, color="yellow") diff --git a/src/llama_stack/apis/inference/inference.py b/src/llama_stack/apis/inference/inference.py index 1a865ce5f..9f04917c9 100644 --- a/src/llama_stack/apis/inference/inference.py +++ b/src/llama_stack/apis/inference/inference.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from collections.abc import AsyncIterator -from enum import Enum +from enum import Enum, StrEnum from typing import ( Annotated, Any, @@ -15,28 +15,18 @@ from typing import ( ) from fastapi import Body -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from typing_extensions import TypedDict -from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent -from llama_stack.apis.common.responses import MetricResponseMixin, Order +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.responses import ( + Order, +) from llama_stack.apis.common.tracing import telemetry_traceable from llama_stack.apis.models import Model from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA -from llama_stack.models.llama.datatypes import ( - BuiltinTool, - StopReason, - ToolCall, - ToolDefinition, - ToolPromptFormat, -) from llama_stack.schema_utils import json_schema_type, register_schema, webmethod -register_schema(ToolCall) -register_schema(ToolDefinition) - -from enum import StrEnum - @json_schema_type class GreedySamplingStrategy(BaseModel): @@ -201,58 +191,6 @@ class ToolResponseMessage(BaseModel): content: InterleavedContent -@json_schema_type -class CompletionMessage(BaseModel): - """A message containing the model's (assistant) response in a chat conversation. - - :param role: Must be "assistant" to identify this as the model's response - :param content: The content of the model's response - :param stop_reason: Reason why the model stopped generating. Options are: - - `StopReason.end_of_turn`: The model finished generating the entire response. - - `StopReason.end_of_message`: The model finished generating but generated a partial response -- usually, a tool call. The user may call the tool and continue the conversation with the tool's response. - - `StopReason.out_of_tokens`: The model ran out of token budget. - :param tool_calls: List of tool calls. Each tool call is a ToolCall object. 
- """ - - role: Literal["assistant"] = "assistant" - content: InterleavedContent - stop_reason: StopReason - tool_calls: list[ToolCall] | None = Field(default_factory=lambda: []) - - -Message = Annotated[ - UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage, - Field(discriminator="role"), -] -register_schema(Message, name="Message") - - -@json_schema_type -class ToolResponse(BaseModel): - """Response from a tool invocation. - - :param call_id: Unique identifier for the tool call this response is for - :param tool_name: Name of the tool that was invoked - :param content: The response content from the tool - :param metadata: (Optional) Additional metadata about the tool response - """ - - call_id: str - tool_name: BuiltinTool | str - content: InterleavedContent - metadata: dict[str, Any] | None = None - - @field_validator("tool_name", mode="before") - @classmethod - def validate_field(cls, v): - if isinstance(v, str): - try: - return BuiltinTool(v) - except ValueError: - return v - return v - - class ToolChoice(Enum): """Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model. @@ -289,22 +227,6 @@ class ChatCompletionResponseEventType(Enum): progress = "progress" -@json_schema_type -class ChatCompletionResponseEvent(BaseModel): - """An event during chat completion generation. - - :param event_type: Type of the event - :param delta: Content generated since last event. This can be one or more tokens, or a tool call. - :param logprobs: Optional log probabilities for generated tokens - :param stop_reason: Optional reason why generation stopped, if complete - """ - - event_type: ChatCompletionResponseEventType - delta: ContentDelta - logprobs: list[TokenLogProbs] | None = None - stop_reason: StopReason | None = None - - class ResponseFormatType(StrEnum): """Types of formats for structured (guided) decoding. @@ -357,34 +279,6 @@ class CompletionRequest(BaseModel): logprobs: LogProbConfig | None = None -@json_schema_type -class CompletionResponse(MetricResponseMixin): - """Response from a completion request. - - :param content: The generated completion text - :param stop_reason: Reason why generation stopped - :param logprobs: Optional log probabilities for generated tokens - """ - - content: str - stop_reason: StopReason - logprobs: list[TokenLogProbs] | None = None - - -@json_schema_type -class CompletionResponseStreamChunk(MetricResponseMixin): - """A chunk of a streamed completion response. - - :param delta: New content generated since last chunk. This can be one or more tokens. - :param stop_reason: Optional reason why generation stopped, if complete - :param logprobs: Optional log probabilities for generated tokens - """ - - delta: str - stop_reason: StopReason | None = None - logprobs: list[TokenLogProbs] | None = None - - class SystemMessageBehavior(Enum): """Config for how to override the default system prompt. @@ -398,70 +292,6 @@ class SystemMessageBehavior(Enum): replace = "replace" -@json_schema_type -class ToolConfig(BaseModel): - """Configuration for tool use. - - :param tool_choice: (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto. - :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. 
- - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls. - :param system_message_behavior: (Optional) Config for how to override the default system prompt. - - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string - '{{function_definitions}}' to indicate where the function definitions should be inserted. - """ - - tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto) - tool_prompt_format: ToolPromptFormat | None = Field(default=None) - system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append) - - def model_post_init(self, __context: Any) -> None: - if isinstance(self.tool_choice, str): - try: - self.tool_choice = ToolChoice[self.tool_choice] - except KeyError: - pass - - -# This is an internally used class -@json_schema_type -class ChatCompletionRequest(BaseModel): - model: str - messages: list[Message] - sampling_params: SamplingParams | None = Field(default_factory=SamplingParams) - - tools: list[ToolDefinition] | None = Field(default_factory=lambda: []) - tool_config: ToolConfig | None = Field(default_factory=ToolConfig) - - response_format: ResponseFormat | None = None - stream: bool | None = False - logprobs: LogProbConfig | None = None - - -@json_schema_type -class ChatCompletionResponseStreamChunk(MetricResponseMixin): - """A chunk of a streamed chat completion response. - - :param event: The event containing the new content - """ - - event: ChatCompletionResponseEvent - - -@json_schema_type -class ChatCompletionResponse(MetricResponseMixin): - """Response from a chat completion request. - - :param completion_message: The complete response message - :param logprobs: Optional log probabilities for generated tokens - """ - - completion_message: CompletionMessage - logprobs: list[TokenLogProbs] | None = None - - @json_schema_type class EmbeddingsResponse(BaseModel): """Response containing generated embeddings. diff --git a/src/llama_stack/apis/inspect/inspect.py b/src/llama_stack/apis/inspect/inspect.py index 4e0e2548b..235abb124 100644 --- a/src/llama_stack/apis/inspect/inspect.py +++ b/src/llama_stack/apis/inspect/inspect.py @@ -76,7 +76,7 @@ class Inspect(Protocol): List all available API routes with their methods and implementing providers. - :param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes. + :param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns all non-deprecated routes. :returns: Response containing information about all available routes. """ ... diff --git a/src/llama_stack/apis/models/models.py b/src/llama_stack/apis/models/models.py index 5c976886c..bbb359b51 100644 --- a/src/llama_stack/apis/models/models.py +++ b/src/llama_stack/apis/models/models.py @@ -136,7 +136,7 @@ class Models(Protocol): """ ... 
- @webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1) + @webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1, deprecated=True) async def register_model( self, model_id: str, @@ -158,7 +158,7 @@ class Models(Protocol): """ ... - @webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1) + @webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True) async def unregister_model( self, model_id: str, diff --git a/src/llama_stack/apis/scoring_functions/scoring_functions.py b/src/llama_stack/apis/scoring_functions/scoring_functions.py index fe49723ab..78f4a7541 100644 --- a/src/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/src/llama_stack/apis/scoring_functions/scoring_functions.py @@ -178,7 +178,7 @@ class ScoringFunctions(Protocol): """ ... - @webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1) + @webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True) async def register_scoring_function( self, scoring_fn_id: str, @@ -199,7 +199,9 @@ class ScoringFunctions(Protocol): """ ... - @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1) + @webmethod( + route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True + ) async def unregister_scoring_function(self, scoring_fn_id: str) -> None: """Unregister a scoring function. diff --git a/src/llama_stack/apis/shields/shields.py b/src/llama_stack/apis/shields/shields.py index ca4483828..659ba8b75 100644 --- a/src/llama_stack/apis/shields/shields.py +++ b/src/llama_stack/apis/shields/shields.py @@ -67,7 +67,7 @@ class Shields(Protocol): """ ... - @webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1) + @webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1, deprecated=True) async def register_shield( self, shield_id: str, @@ -85,7 +85,7 @@ class Shields(Protocol): """ ... - @webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1) + @webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True) async def unregister_shield(self, identifier: str) -> None: """Unregister a shield. diff --git a/src/llama_stack/apis/tools/tools.py b/src/llama_stack/apis/tools/tools.py index c9bdfcfb6..4e7cf2544 100644 --- a/src/llama_stack/apis/tools/tools.py +++ b/src/llama_stack/apis/tools/tools.py @@ -109,7 +109,7 @@ class ListToolDefsResponse(BaseModel): @runtime_checkable @telemetry_traceable class ToolGroups(Protocol): - @webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1) + @webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1, deprecated=True) async def register_tool_group( self, toolgroup_id: str, @@ -167,7 +167,7 @@ class ToolGroups(Protocol): """ ... 
- @webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1) + @webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True) async def unregister_toolgroup( self, toolgroup_id: str, diff --git a/src/llama_stack/apis/vector_io/vector_io.py b/src/llama_stack/apis/vector_io/vector_io.py index 26c961db3..846c6f191 100644 --- a/src/llama_stack/apis/vector_io/vector_io.py +++ b/src/llama_stack/apis/vector_io/vector_io.py @@ -396,19 +396,19 @@ class VectorStoreListFilesResponse(BaseModel): @json_schema_type -class VectorStoreFileContentsResponse(BaseModel): - """Response from retrieving the contents of a vector store file. +class VectorStoreFileContentResponse(BaseModel): + """Represents the parsed content of a vector store file. - :param file_id: Unique identifier for the file - :param filename: Name of the file - :param attributes: Key-value attributes associated with the file - :param content: List of content items from the file + :param object: The object type, which is always `vector_store.file_content.page` + :param data: Parsed content of the file + :param has_more: Indicates if there are more content pages to fetch + :param next_page: The token for the next page, if any """ - file_id: str - filename: str - attributes: dict[str, Any] - content: list[VectorStoreContent] + object: Literal["vector_store.file_content.page"] = "vector_store.file_content.page" + data: list[VectorStoreContent] + has_more: bool + next_page: str | None = None @json_schema_type @@ -732,12 +732,12 @@ class VectorIO(Protocol): self, vector_store_id: str, file_id: str, - ) -> VectorStoreFileContentsResponse: + ) -> VectorStoreFileContentResponse: """Retrieves the contents of a vector store file. :param vector_store_id: The ID of the vector store containing the file to retrieve. :param file_id: The ID of the file to retrieve. - :returns: A list of InterleavedContent representing the file contents. + :returns: A VectorStoreFileContentResponse representing the file contents. """ ... 
diff --git a/src/llama_stack/core/inspect.py b/src/llama_stack/core/inspect.py index 6352af00f..07b51128f 100644 --- a/src/llama_stack/core/inspect.py +++ b/src/llama_stack/core/inspect.py @@ -15,7 +15,6 @@ from llama_stack.apis.inspect import ( RouteInfo, VersionInfo, ) -from llama_stack.apis.version import LLAMA_STACK_API_V1 from llama_stack.core.datatypes import StackRunConfig from llama_stack.core.external import load_external_apis from llama_stack.core.server.routes import get_all_api_routes @@ -46,8 +45,8 @@ class DistributionInspectImpl(Inspect): # Helper function to determine if a route should be included based on api_filter def should_include_route(webmethod) -> bool: if api_filter is None: - # Default: only non-deprecated v1 APIs - return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1 + # Default: only non-deprecated APIs + return not webmethod.deprecated elif api_filter == "deprecated": # Special filter: show deprecated routes regardless of their actual level return bool(webmethod.deprecated) diff --git a/src/llama_stack/core/routers/safety.py b/src/llama_stack/core/routers/safety.py index 79eac8b46..e5ff2ada9 100644 --- a/src/llama_stack/core/routers/safety.py +++ b/src/llama_stack/core/routers/safety.py @@ -6,7 +6,7 @@ from typing import Any -from llama_stack.apis.inference import Message +from llama_stack.apis.inference import OpenAIMessageParam from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.shields import Shield @@ -52,7 +52,7 @@ class SafetyRouter(Safety): async def run_shield( self, shield_id: str, - messages: list[Message], + messages: list[OpenAIMessageParam], params: dict[str, Any] = None, ) -> RunShieldResponse: logger.debug(f"SafetyRouter.run_shield: {shield_id}") diff --git a/src/llama_stack/core/routers/vector_io.py b/src/llama_stack/core/routers/vector_io.py index b54217619..9dac461db 100644 --- a/src/llama_stack/core/routers/vector_io.py +++ b/src/llama_stack/core/routers/vector_io.py @@ -24,7 +24,7 @@ from llama_stack.apis.vector_io import ( VectorStoreChunkingStrategyStaticConfig, VectorStoreDeleteResponse, VectorStoreFileBatchObject, - VectorStoreFileContentsResponse, + VectorStoreFileContentResponse, VectorStoreFileDeleteResponse, VectorStoreFileObject, VectorStoreFilesListInBatchResponse, @@ -338,7 +338,7 @@ class VectorIORouter(VectorIO): self, vector_store_id: str, file_id: str, - ) -> VectorStoreFileContentsResponse: + ) -> VectorStoreFileContentResponse: logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}") provider = await self.routing_table.get_provider_impl(vector_store_id) return await provider.openai_retrieve_vector_store_file_contents( diff --git a/src/llama_stack/core/routing_tables/vector_stores.py b/src/llama_stack/core/routing_tables/vector_stores.py index c6c80a01e..f95a4dbe3 100644 --- a/src/llama_stack/core/routing_tables/vector_stores.py +++ b/src/llama_stack/core/routing_tables/vector_stores.py @@ -15,7 +15,7 @@ from llama_stack.apis.vector_io.vector_io import ( SearchRankingOptions, VectorStoreChunkingStrategy, VectorStoreDeleteResponse, - VectorStoreFileContentsResponse, + VectorStoreFileContentResponse, VectorStoreFileDeleteResponse, VectorStoreFileObject, VectorStoreFileStatus, @@ -195,7 +195,7 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): self, vector_store_id: str, file_id: str, - ) -> VectorStoreFileContentsResponse: + ) -> VectorStoreFileContentResponse: 
await self.assert_action_allowed("read", "vector_store", vector_store_id) provider = await self.get_provider_impl(vector_store_id) return await provider.openai_retrieve_vector_store_file_contents( diff --git a/src/llama_stack/distributions/oci/__init__.py b/src/llama_stack/distributions/oci/__init__.py new file mode 100644 index 000000000..68c0efe44 --- /dev/null +++ b/src/llama_stack/distributions/oci/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .oci import get_distribution_template # noqa: F401 diff --git a/src/llama_stack/distributions/oci/build.yaml b/src/llama_stack/distributions/oci/build.yaml new file mode 100644 index 000000000..7e082e1f6 --- /dev/null +++ b/src/llama_stack/distributions/oci/build.yaml @@ -0,0 +1,35 @@ +version: 2 +distribution_spec: + description: Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM + inference with scalable cloud services + providers: + inference: + - provider_type: remote::oci + vector_io: + - provider_type: inline::faiss + - provider_type: remote::chromadb + - provider_type: remote::pgvector + safety: + - provider_type: inline::llama-guard + agents: + - provider_type: inline::meta-reference + eval: + - provider_type: inline::meta-reference + datasetio: + - provider_type: remote::huggingface + - provider_type: inline::localfs + scoring: + - provider_type: inline::basic + - provider_type: inline::llm-as-judge + - provider_type: inline::braintrust + tool_runtime: + - provider_type: remote::brave-search + - provider_type: remote::tavily-search + - provider_type: inline::rag-runtime + - provider_type: remote::model-context-protocol + files: + - provider_type: inline::localfs +image_type: venv +additional_pip_packages: +- aiosqlite +- sqlalchemy[asyncio] diff --git a/src/llama_stack/distributions/oci/doc_template.md b/src/llama_stack/distributions/oci/doc_template.md new file mode 100644 index 000000000..320530ccd --- /dev/null +++ b/src/llama_stack/distributions/oci/doc_template.md @@ -0,0 +1,140 @@ +--- +orphan: true +--- +# OCI Distribution + +The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations. + +{{ providers_table }} + +{% if run_config_env_vars %} +### Environment Variables + +The following environment variables can be configured: + +{% for var, (default_value, description) in run_config_env_vars.items() %} +- `{{ var }}`: {{ description }} (default: `{{ default_value }}`) +{% endfor %} +{% endif %} + +{% if default_models %} +### Models + +The following models are available by default: + +{% for model in default_models %} +- `{{ model.model_id }} {{ model.doc_string }}` +{% endfor %} +{% endif %} + +## Prerequisites +### Oracle Cloud Infrastructure Setup + +Before using the OCI Generative AI distribution, ensure you have: + +1. **Oracle Cloud Infrastructure Account**: Sign up at [Oracle Cloud Infrastructure](https://cloud.oracle.com/) +2. **Generative AI Service Access**: Enable the Generative AI service in your OCI tenancy +3. **Compartment**: Create or identify a compartment where you'll deploy Generative AI models +4. 
**Authentication**: Configure authentication using either: + - **Instance Principal** (recommended for cloud-hosted deployments) + - **API Key** (for on-premises or development environments) + +### Authentication Methods + +#### Instance Principal Authentication (Recommended) +Instance Principal authentication allows OCI resources to authenticate using the identity of the compute instance they're running on. This is the most secure method for production deployments. + +Requirements: +- Instance must be running in an Oracle Cloud Infrastructure compartment +- Instance must have appropriate IAM policies to access Generative AI services + +#### API Key Authentication +For development or on-premises deployments, follow [this doc](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/apisigningkey.htm) to learn how to create your API signing key for your config file. + +### Required IAM Policies + +Ensure your OCI user or instance has the following policy statements: + +``` +Allow group <group-name> to use generative-ai-inference-endpoints in compartment <compartment-name> +Allow group <group-name> to manage generative-ai-inference-endpoints in compartment <compartment-name> +``` + +## Supported Services + +### Inference: OCI Generative AI +Oracle Cloud Infrastructure Generative AI provides access to high-performance AI models through OCI's Platform-as-a-Service offering. The service supports: + +- **Chat Completions**: Conversational AI with context awareness +- **Text Generation**: Complete prompts and generate text content + +#### Available Models +Models from Meta, Cohere, OpenAI, Grok, and more are available through the OCI Generative AI service. + +### Safety: Llama Guard +For content safety and moderation, this distribution uses Meta's Llama Guard model through the OCI Generative AI service to provide: +- Content filtering and moderation +- Policy compliance checking +- Harmful content detection + +### Vector Storage: Multiple Options +The distribution supports several vector storage providers: +- **FAISS**: Local in-memory vector search +- **ChromaDB**: Distributed vector database +- **PGVector**: PostgreSQL with vector extensions + +### Additional Services +- **Dataset I/O**: Local filesystem and Hugging Face integration +- **Tool Runtime**: Web search (Brave, Tavily) and RAG capabilities +- **Evaluation**: Meta reference evaluation framework + +## Running Llama Stack with OCI + +You can run the OCI distribution via Docker or a local virtual environment. + +### Via venv + +If you've set up your local development environment, you can run the stack directly from your local virtual environment: + +```bash +OCI_AUTH_TYPE=$OCI_AUTH_TYPE OCI_REGION=$OCI_REGION OCI_COMPARTMENT_OCID=$OCI_COMPARTMENT_OCID llama stack run --port 8321 oci +``` + +### Configuration Examples + +#### Using Instance Principal (Recommended for Production) +```bash +export OCI_AUTH_TYPE=instance_principal +export OCI_REGION=us-chicago-1 +export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1.. +``` + +#### Using API Key Authentication (Development) +```bash +export OCI_AUTH_TYPE=config_file +export OCI_CONFIG_FILE_PATH=~/.oci/config +export OCI_CLI_PROFILE=DEFAULT +export OCI_REGION=us-chicago-1 +export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..your-compartment-id +``` + +## Regional Endpoints + +OCI Generative AI is available in multiple regions. The service automatically routes to the appropriate regional endpoint based on your configuration.
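For example, to target a different region, export it before starting the stack (illustrative values only; use a region where your tenancy has Generative AI access):

```bash
export OCI_REGION=eu-frankfurt-1
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..your-compartment-id
llama stack run --port 8321 oci
```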
For a full list of regional model availability, visit: + +https://docs.oracle.com/en-us/iaas/Content/generative-ai/overview.htm#regions + +## Troubleshooting + +### Common Issues + +1. **Authentication Errors**: Verify your OCI credentials and IAM policies +2. **Model Not Found**: Ensure the model OCID is correct and the model is available in your region +3. **Permission Denied**: Check compartment permissions and Generative AI service access +4. **Region Unavailable**: Verify the specified region supports Generative AI services + +### Getting Help + +For additional support: +- [OCI Generative AI Documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm) +- [Llama Stack Issues](https://github.com/meta-llama/llama-stack/issues) \ No newline at end of file diff --git a/src/llama_stack/distributions/oci/oci.py b/src/llama_stack/distributions/oci/oci.py new file mode 100644 index 000000000..1f21840f1 --- /dev/null +++ b/src/llama_stack/distributions/oci/oci.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pathlib import Path + +from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput +from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings +from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig +from llama_stack.providers.remote.inference.oci.config import OCIConfig + + +def get_distribution_template(name: str = "oci") -> DistributionTemplate: + providers = { + "inference": [BuildProvider(provider_type="remote::oci")], + "vector_io": [ + BuildProvider(provider_type="inline::faiss"), + BuildProvider(provider_type="remote::chromadb"), + BuildProvider(provider_type="remote::pgvector"), + ], + "safety": [BuildProvider(provider_type="inline::llama-guard")], + "agents": [BuildProvider(provider_type="inline::meta-reference")], + "eval": [BuildProvider(provider_type="inline::meta-reference")], + "datasetio": [ + BuildProvider(provider_type="remote::huggingface"), + BuildProvider(provider_type="inline::localfs"), + ], + "scoring": [ + BuildProvider(provider_type="inline::basic"), + BuildProvider(provider_type="inline::llm-as-judge"), + BuildProvider(provider_type="inline::braintrust"), + ], + "tool_runtime": [ + BuildProvider(provider_type="remote::brave-search"), + BuildProvider(provider_type="remote::tavily-search"), + BuildProvider(provider_type="inline::rag-runtime"), + BuildProvider(provider_type="remote::model-context-protocol"), + ], + "files": [BuildProvider(provider_type="inline::localfs")], + } + + inference_provider = Provider( + provider_id="oci", + provider_type="remote::oci", + config=OCIConfig.sample_run_config(), + ) + + vector_io_provider = Provider( + provider_id="faiss", + provider_type="inline::faiss", + config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"), + ) + + files_provider = Provider( + provider_id="meta-reference-files", + provider_type="inline::localfs", + config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"), + ) + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ] + + return DistributionTemplate( + name=name, + distro_type="remote_hosted", + description="Use Oracle Cloud 
Infrastructure (OCI) Generative AI for running LLM inference with scalable cloud services", + container_image=None, + template_path=Path(__file__).parent / "doc_template.md", + providers=providers, + run_configs={ + "run.yaml": RunConfigSettings( + provider_overrides={ + "inference": [inference_provider], + "vector_io": [vector_io_provider], + "files": [files_provider], + }, + default_tool_groups=default_tool_groups, + ), + }, + run_config_env_vars={ + "OCI_AUTH_TYPE": ( + "instance_principal", + "OCI authentication type (instance_principal or config_file)", + ), + "OCI_REGION": ( + "", + "OCI region (e.g., us-ashburn-1, us-chicago-1, us-phoenix-1, eu-frankfurt-1)", + ), + "OCI_COMPARTMENT_OCID": ( + "", + "OCI compartment ID for the Generative AI service", + ), + "OCI_CONFIG_FILE_PATH": ( + "~/.oci/config", + "OCI config file path (required if OCI_AUTH_TYPE is config_file)", + ), + "OCI_CLI_PROFILE": ( + "DEFAULT", + "OCI CLI profile name to use from config file", + ), + }, + ) diff --git a/src/llama_stack/distributions/oci/run.yaml b/src/llama_stack/distributions/oci/run.yaml new file mode 100644 index 000000000..e385ec606 --- /dev/null +++ b/src/llama_stack/distributions/oci/run.yaml @@ -0,0 +1,136 @@ +version: 2 +image_name: oci +apis: +- agents +- datasetio +- eval +- files +- inference +- safety +- scoring +- tool_runtime +- vector_io +providers: + inference: + - provider_id: oci + provider_type: remote::oci + config: + oci_auth_type: ${env.OCI_AUTH_TYPE:=instance_principal} + oci_config_file_path: ${env.OCI_CONFIG_FILE_PATH:=~/.oci/config} + oci_config_profile: ${env.OCI_CLI_PROFILE:=DEFAULT} + oci_region: ${env.OCI_REGION:=us-ashburn-1} + oci_compartment_id: ${env.OCI_COMPARTMENT_OCID:=} + vector_io: + - provider_id: faiss + provider_type: inline::faiss + config: + persistence: + namespace: vector_io::faiss + backend: kv_default + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence: + agent_state: + namespace: agents + backend: kv_default + responses: + table_name: responses + backend: sql_default + max_write_queue_size: 10000 + num_writers: 4 + eval: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + kvstore: + namespace: eval + backend: kv_default + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: + kvstore: + namespace: datasetio::huggingface + backend: kv_default + - provider_id: localfs + provider_type: inline::localfs + config: + kvstore: + namespace: datasetio::localfs + backend: kv_default + scoring: + - provider_id: basic + provider_type: inline::basic + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + - provider_id: braintrust + provider_type: inline::braintrust + config: + openai_api_key: ${env.OPENAI_API_KEY:=} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:=} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:=} + max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + files: + - provider_id: meta-reference-files + provider_type: inline::localfs + config: + storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/oci/files} + 
metadata_store: + table_name: files_metadata + backend: sql_default +storage: + backends: + kv_default: + type: kv_sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/kvstore.db + sql_default: + type: sql_sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/sql_store.db + stores: + metadata: + namespace: registry + backend: kv_default + inference: + table_name: inference_store + backend: sql_default + max_write_queue_size: 10000 + num_writers: 4 + conversations: + table_name: openai_conversations + backend: sql_default + prompts: + namespace: prompts + backend: kv_default +registered_resources: + models: [] + shields: [] + vector_dbs: [] + datasets: [] + scoring_fns: [] + benchmarks: [] + tool_groups: + - toolgroup_id: builtin::websearch + provider_id: tavily-search +server: + port: 8321 +telemetry: + enabled: true diff --git a/src/llama_stack/models/llama/llama3/generation.py b/src/llama_stack/models/llama/llama3/generation.py index fe7be5ea9..9ac215c3b 100644 --- a/src/llama_stack/models/llama/llama3/generation.py +++ b/src/llama_stack/models/llama/llama3/generation.py @@ -26,8 +26,10 @@ from fairscale.nn.model_parallel.initialize import ( ) from termcolor import cprint +from llama_stack.models.llama.datatypes import ToolPromptFormat + from ..checkpoint import maybe_reshard_state_dict -from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat +from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage from .args import ModelArgs from .chat_format import ChatFormat, LLMInput from .model import Transformer diff --git a/src/llama_stack/models/llama/llama3/interface.py b/src/llama_stack/models/llama/llama3/interface.py index b63ba4847..89be31a55 100644 --- a/src/llama_stack/models/llama/llama3/interface.py +++ b/src/llama_stack/models/llama/llama3/interface.py @@ -15,13 +15,10 @@ from pathlib import Path from termcolor import colored +from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall, ToolDefinition, ToolPromptFormat + from ..datatypes import ( - BuiltinTool, RawMessage, - StopReason, - ToolCall, - ToolDefinition, - ToolPromptFormat, ) from . 
import template_data from .chat_format import ChatFormat diff --git a/src/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py b/src/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py index 11a5993e9..3fbaa103e 100644 --- a/src/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py +++ b/src/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py @@ -15,7 +15,7 @@ import textwrap from datetime import datetime from typing import Any -from llama_stack.apis.inference import ( +from llama_stack.models.llama.datatypes import ( BuiltinTool, ToolDefinition, ) diff --git a/src/llama_stack/models/llama/llama3/tool_utils.py b/src/llama_stack/models/llama/llama3/tool_utils.py index 8c12fe680..6f919e1fa 100644 --- a/src/llama_stack/models/llama/llama3/tool_utils.py +++ b/src/llama_stack/models/llama/llama3/tool_utils.py @@ -8,8 +8,9 @@ import json import re from llama_stack.log import get_logger +from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolPromptFormat -from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat +from ..datatypes import RecursiveType logger = get_logger(name=__name__, category="models::llama") diff --git a/src/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py b/src/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py index 1ee570933..feded9f8c 100644 --- a/src/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py +++ b/src/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py @@ -13,7 +13,7 @@ import textwrap -from llama_stack.apis.inference import ToolDefinition +from llama_stack.models.llama.datatypes import ToolDefinition from llama_stack.models.llama.llama3.prompt_templates.base import ( PromptTemplate, PromptTemplateGeneratorBase, diff --git a/src/llama_stack/providers/inline/agents/meta_reference/agents.py b/src/llama_stack/providers/inline/agents/meta_reference/agents.py index 7141d58bc..880e0b680 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -102,6 +102,7 @@ class MetaReferenceAgentsImpl(Agents): include: list[str] | None = None, max_infer_iters: int | None = 10, guardrails: list[ResponseGuardrail] | None = None, + max_tool_calls: int | None = None, ) -> OpenAIResponseObject: assert self.openai_responses_impl is not None, "OpenAI responses not initialized" result = await self.openai_responses_impl.create_openai_response( @@ -119,6 +120,7 @@ class MetaReferenceAgentsImpl(Agents): include, max_infer_iters, guardrails, + max_tool_calls, ) return result # type: ignore[no-any-return] diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py index 933cfe963..ed7f959c0 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -255,6 +255,7 @@ class OpenAIResponsesImpl: include: list[str] | None = None, max_infer_iters: int | None = 10, guardrails: list[str | ResponseGuardrailSpec] | None = None, + max_tool_calls: int | None = None, ): stream = bool(stream) text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text @@ -270,6 +271,9 @@ class OpenAIResponsesImpl: if not conversation.startswith("conv_"): raise 
InvalidConversationIdError(conversation) + if max_tool_calls is not None and max_tool_calls < 1: + raise ValueError(f"Invalid {max_tool_calls=}; should be >= 1") + stream_gen = self._create_streaming_response( input=input, conversation=conversation, @@ -282,6 +286,7 @@ class OpenAIResponsesImpl: tools=tools, max_infer_iters=max_infer_iters, guardrail_ids=guardrail_ids, + max_tool_calls=max_tool_calls, ) if stream: @@ -331,6 +336,7 @@ class OpenAIResponsesImpl: tools: list[OpenAIResponseInputTool] | None = None, max_infer_iters: int | None = 10, guardrail_ids: list[str] | None = None, + max_tool_calls: int | None = None, ) -> AsyncIterator[OpenAIResponseObjectStream]: # These should never be None when called from create_openai_response (which sets defaults) # but we assert here to help mypy understand the types @@ -373,6 +379,7 @@ class OpenAIResponsesImpl: safety_api=self.safety_api, guardrail_ids=guardrail_ids, instructions=instructions, + max_tool_calls=max_tool_calls, ) # Stream the response diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index ef5603420..c16bc8df3 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -115,6 +115,7 @@ class StreamingResponseOrchestrator: safety_api, guardrail_ids: list[str] | None = None, prompt: OpenAIResponsePrompt | None = None, + max_tool_calls: int | None = None, ): self.inference_api = inference_api self.ctx = ctx @@ -126,6 +127,10 @@ class StreamingResponseOrchestrator: self.safety_api = safety_api self.guardrail_ids = guardrail_ids or [] self.prompt = prompt + # System message that is inserted into the model's context + self.instructions = instructions + # Max number of total calls to built-in tools that can be processed in a response + self.max_tool_calls = max_tool_calls self.sequence_number = 0 # Store MCP tool mapping that gets built during tool processing self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ( @@ -139,8 +144,8 @@ class StreamingResponseOrchestrator: self.accumulated_usage: OpenAIResponseUsage | None = None # Track if we've sent a refusal response self.violation_detected = False - # system message that is inserted into the model's context - self.instructions = instructions + # Track total calls made to built-in tools + self.accumulated_builtin_tool_calls = 0 async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream: """Create a refusal response to replace streaming content.""" @@ -186,6 +191,7 @@ class StreamingResponseOrchestrator: usage=self.accumulated_usage, instructions=self.instructions, prompt=self.prompt, + max_tool_calls=self.max_tool_calls, ) async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]: @@ -894,6 +900,11 @@ class StreamingResponseOrchestrator: """Coordinate execution of both function and non-function tool calls.""" # Execute non-function tool calls for tool_call in non_function_tool_calls: + # Check if total calls made to built-in and mcp tools exceed max_tool_calls + if self.max_tool_calls is not None and self.accumulated_builtin_tool_calls >= self.max_tool_calls: + logger.info(f"Ignoring built-in and mcp tool call since reached the limit of {self.max_tool_calls=}.") + break + # Find the item_id for this tool call matching_item_id = None for index, item_id in 
completion_result_data.tool_call_item_ids.items(): @@ -974,6 +985,9 @@ class StreamingResponseOrchestrator: if tool_response_message: next_turn_messages.append(tool_response_message) + # Track number of calls made to built-in and mcp tools + self.accumulated_builtin_tool_calls += 1 + # Execute function tool calls (client-side) for tool_call in function_tool_calls: # Find the item_id for this tool call from our tracking dictionary diff --git a/src/llama_stack/providers/inline/inference/meta_reference/generators.py b/src/llama_stack/providers/inline/inference/meta_reference/generators.py index cb926f529..51a2ddfad 100644 --- a/src/llama_stack/providers/inline/inference/meta_reference/generators.py +++ b/src/llama_stack/providers/inline/inference/meta_reference/generators.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import math -from collections.abc import Generator from typing import Optional import torch @@ -14,21 +13,19 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken from llama_stack.apis.inference import ( GreedySamplingStrategy, JsonSchemaResponseFormat, + OpenAIChatCompletionRequestWithExtraBody, + OpenAIResponseFormatJSONSchema, ResponseFormat, + ResponseFormatType, SamplingParams, TopPSamplingStrategy, ) -from llama_stack.models.llama.datatypes import QuantizationMode +from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat from llama_stack.models.llama.llama3.generation import Llama3 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama4.generation import Llama4 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer from llama_stack.models.llama.sku_types import Model, ModelFamily -from llama_stack.providers.utils.inference.prompt_adapter import ( - ChatCompletionRequestWithRawContent, - CompletionRequestWithRawContent, - get_default_tool_prompt_format, -) from .common import model_checkpoint_dir from .config import MetaReferenceInferenceConfig @@ -106,14 +103,6 @@ def _infer_sampling_params(sampling_params: SamplingParams): return temperature, top_p -def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent): - tool_config = request.tool_config - if tool_config is not None and tool_config.tool_prompt_format is not None: - return tool_config.tool_prompt_format - else: - return get_default_tool_prompt_format(request.model) - - class LlamaGenerator: def __init__( self, @@ -157,55 +146,56 @@ class LlamaGenerator: self.args = self.inner_generator.args self.formatter = self.inner_generator.formatter - def completion( - self, - request_batch: list[CompletionRequestWithRawContent], - ) -> Generator: - first_request = request_batch[0] - sampling_params = first_request.sampling_params or SamplingParams() - max_gen_len = sampling_params.max_tokens - if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: - max_gen_len = self.args.max_seq_len - 1 - - temperature, top_p = _infer_sampling_params(sampling_params) - yield from self.inner_generator.generate( - llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch], - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - logprobs=bool(first_request.logprobs), - echo=False, - logits_processor=get_logits_processor( - self.tokenizer, - self.args.vocab_size, - first_request.response_format, - ), - ) - def chat_completion( self, - request_batch: list[ChatCompletionRequestWithRawContent], - ) -> 
Generator: - first_request = request_batch[0] - sampling_params = first_request.sampling_params or SamplingParams() + request: OpenAIChatCompletionRequestWithExtraBody, + raw_messages: list, + ): + """Generate chat completion using OpenAI request format. + + Args: + request: OpenAI chat completion request + raw_messages: Pre-converted list of RawMessage objects + """ + + # Determine tool prompt format + tool_prompt_format = ToolPromptFormat.json if request.tools else ToolPromptFormat.json + + # Prepare sampling params + sampling_params = SamplingParams() + if request.temperature is not None or request.top_p is not None: + sampling_params.strategy = TopPSamplingStrategy( + temperature=request.temperature if request.temperature is not None else 1.0, + top_p=request.top_p if request.top_p is not None else 1.0, + ) + if request.max_tokens: + sampling_params.max_tokens = request.max_tokens + max_gen_len = sampling_params.max_tokens if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: max_gen_len = self.args.max_seq_len - 1 temperature, top_p = _infer_sampling_params(sampling_params) + + # Get logits processor for response format + logits_processor = None + if request.response_format: + if isinstance(request.response_format, OpenAIResponseFormatJSONSchema): + # Extract the actual schema from OpenAIJSONSchema TypedDict + schema_dict = request.response_format.json_schema.get("schema") or {} + json_schema_format = JsonSchemaResponseFormat( + type=ResponseFormatType.json_schema, + json_schema=schema_dict, + ) + logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format) + + # Generate yield from self.inner_generator.generate( - llm_inputs=[ - self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)) - for request in request_batch - ], + llm_inputs=[self.formatter.encode_dialog_prompt(raw_messages, tool_prompt_format)], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, - logprobs=bool(first_request.logprobs), + logprobs=False, echo=False, - logits_processor=get_logits_processor( - self.tokenizer, - self.args.vocab_size, - first_request.response_format, - ), + logits_processor=logits_processor, ) diff --git a/src/llama_stack/providers/inline/inference/meta_reference/inference.py b/src/llama_stack/providers/inline/inference/meta_reference/inference.py index 76d3fdd50..ef21132a0 100644 --- a/src/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/src/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -5,12 +5,19 @@ # the root directory of this source tree. 
import asyncio +import time +import uuid from collections.abc import AsyncIterator from llama_stack.apis.inference import ( InferenceProvider, + OpenAIAssistantMessageParam, OpenAIChatCompletionRequestWithExtraBody, + OpenAIChatCompletionUsage, + OpenAIChoice, OpenAICompletionRequestWithExtraBody, + OpenAIUserMessageParam, + ToolChoice, ) from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, @@ -19,12 +26,20 @@ from llama_stack.apis.inference.inference import ( ) from llama_stack.apis.models import Model, ModelType from llama_stack.log import get_logger +from llama_stack.models.llama.datatypes import RawMessage, RawTextItem, ToolDefinition from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat +from llama_stack.models.llama.llama3.prompt_templates import ( + JsonCustomToolGenerator, + SystemDefaultGenerator, +) from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat +from llama_stack.models.llama.llama4.prompt_templates.system_prompts import ( + PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4, +) from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer from llama_stack.models.llama.sku_list import resolve_model -from llama_stack.models.llama.sku_types import ModelFamily +from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, @@ -44,6 +59,170 @@ log = get_logger(__name__, category="inference") SEMAPHORE = asyncio.Semaphore(1) +def _convert_openai_tool_to_tool_definition(tool) -> ToolDefinition: + """Convert OpenAI tool format to ToolDefinition format.""" + # OpenAI tools have function.name and function.parameters + return ToolDefinition( + tool_name=tool.function.name, + description=tool.function.description or "", + parameters=tool.function.parameters or {}, + ) + + +def _get_tool_choice_prompt(tool_choice, tools) -> str: + """Generate prompt text for tool_choice behavior.""" + if not tool_choice or tool_choice == ToolChoice.auto or tool_choice == "auto": + return "" + elif tool_choice == ToolChoice.required or tool_choice == "required": + return "You MUST use one of the provided functions/tools to answer the user query." + elif tool_choice == ToolChoice.none or tool_choice == "none": + return "" + else: + # Specific tool specified + return f"You MUST use the tool `{tool_choice}` to answer the user query." 
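# Illustrative sketch of the helpers above (hypothetical OpenAI-style tool object `t`
# with t.function.name == "get_weather"; example outputs only, not part of the patch):
#   _convert_openai_tool_to_tool_definition(t)  -> ToolDefinition(tool_name="get_weather", ...)
#   _get_tool_choice_prompt("auto", [t])        -> ""
#   _get_tool_choice_prompt("required", [t])    -> "You MUST use one of the provided functions/tools to answer the user query."
#   _get_tool_choice_prompt("get_weather", [t]) -> "You MUST use the tool `get_weather` to answer the user query."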
+ + +def _raw_content_as_str(content) -> str: + """Convert RawContent to string for system messages.""" + if isinstance(content, str): + return content + elif isinstance(content, RawTextItem): + return content.text + elif isinstance(content, list): + return "\n".join(_raw_content_as_str(c) for c in content) + else: + return "" + + +def _augment_raw_messages_for_tools_llama_3_1( + raw_messages: list[RawMessage], + tools: list, + tool_choice, +) -> list[RawMessage]: + """Augment raw messages with tool definitions for Llama 3.1 style models.""" + messages = raw_messages.copy() + existing_system_message = None + if messages and messages[0].role == "system": + existing_system_message = messages.pop(0) + + sys_content = "" + + # Add tool definitions first (if present) + if tools: + # Convert OpenAI tools to ToolDefinitions + tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools] + + # For OpenAI format, all tools are custom (have string names) + tool_gen = JsonCustomToolGenerator() + tool_template = tool_gen.gen(tool_definitions) + sys_content += tool_template.render() + sys_content += "\n" + + # Add default system prompt + default_gen = SystemDefaultGenerator() + default_template = default_gen.gen() + sys_content += default_template.render() + + # Add existing system message if present + if existing_system_message: + sys_content += "\n" + _raw_content_as_str(existing_system_message.content) + + # Add tool choice prompt if needed + if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools): + sys_content += "\n" + tool_choice_prompt + + # Create new system message + new_system_message = RawMessage( + role="system", + content=[RawTextItem(text=sys_content.strip())], + ) + + return [new_system_message] + messages + + +def _augment_raw_messages_for_tools_llama_4( + raw_messages: list[RawMessage], + tools: list, + tool_choice, +) -> list[RawMessage]: + """Augment raw messages with tool definitions for Llama 4/3.2/3.3 style models.""" + messages = raw_messages.copy() + existing_system_message = None + if messages and messages[0].role == "system": + existing_system_message = messages.pop(0) + + sys_content = "" + + # Add tool definitions if present + if tools: + # Convert OpenAI tools to ToolDefinitions + tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools] + + # Use python_list format for Llama 4 + tool_gen = PythonListCustomToolGeneratorLlama4() + system_prompt = None + if existing_system_message: + system_prompt = _raw_content_as_str(existing_system_message.content) + + tool_template = tool_gen.gen(tool_definitions, system_prompt) + sys_content = tool_template.render() + elif existing_system_message: + # No tools, just use existing system message + sys_content = _raw_content_as_str(existing_system_message.content) + + # Add tool choice prompt if needed + if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools): + sys_content += "\n" + tool_choice_prompt + + if sys_content: + new_system_message = RawMessage( + role="system", + content=[RawTextItem(text=sys_content.strip())], + ) + return [new_system_message] + messages + + return messages + + +def augment_raw_messages_for_tools( + raw_messages: list[RawMessage], + params: OpenAIChatCompletionRequestWithExtraBody, + llama_model, +) -> list[RawMessage]: + """Augment raw messages with tool definitions based on model family.""" + if not params.tools: + return raw_messages + + # Determine augmentation strategy based on model family + if llama_model.model_family == 
ModelFamily.llama3_1 or ( + llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id) + ): + # Llama 3.1 and Llama 3.2 multimodal use JSON format + return _augment_raw_messages_for_tools_llama_3_1( + raw_messages, + params.tools, + params.tool_choice, + ) + elif llama_model.model_family in ( + ModelFamily.llama3_2, + ModelFamily.llama3_3, + ModelFamily.llama4, + ): + # Llama 3.2/3.3/4 use python_list format + return _augment_raw_messages_for_tools_llama_4( + raw_messages, + params.tools, + params.tool_choice, + ) + else: + # Default to Llama 3.1 style + return _augment_raw_messages_for_tools_llama_3_1( + raw_messages, + params.tools, + params.tool_choice, + ) + + def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator: return LlamaGenerator(config, model_id, llama_model) @@ -136,10 +315,13 @@ class MetaReferenceInferenceImpl( self.llama_model = llama_model log.info("Warming up...") + await self.openai_chat_completion( - model=model_id, - messages=[{"role": "user", "content": "Hi how are you?"}], - max_tokens=20, + params=OpenAIChatCompletionRequestWithExtraBody( + model=model_id, + messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")], + max_tokens=20, + ) ) log.info("Warmed up!") @@ -155,4 +337,207 @@ class MetaReferenceInferenceImpl( self, params: OpenAIChatCompletionRequestWithExtraBody, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider") + self.check_model(params) + + # Convert OpenAI messages to RawMessages + from llama_stack.models.llama.datatypes import StopReason + from llama_stack.providers.utils.inference.prompt_adapter import ( + convert_openai_message_to_raw_message, + decode_assistant_message, + ) + + raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages] + + # Augment messages with tool definitions if tools are present + raw_messages = augment_raw_messages_for_tools(raw_messages, params, self.llama_model) + + # Call generator's chat_completion method (works for both single-GPU and model-parallel) + if isinstance(self.generator, LlamaGenerator): + generator = self.generator.chat_completion(params, raw_messages) + else: + # Model parallel: submit task to process group + generator = self.generator.group.run_inference(("chat_completion", [params, raw_messages])) + + # Check if streaming is requested + if params.stream: + return self._stream_chat_completion(generator, params) + + # Non-streaming: collect all generated text + generated_text = "" + for result_batch in generator: + for result in result_batch: + if not result.ignore_token and result.source == "output": + generated_text += result.text + + # Decode assistant message to extract tool calls and determine stop_reason + # Default to end_of_turn if generation completed normally + decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn) + + # Convert tool calls to OpenAI format + openai_tool_calls = None + if decoded_message.tool_calls: + from llama_stack.apis.inference import ( + OpenAIChatCompletionToolCall, + OpenAIChatCompletionToolCallFunction, + ) + + openai_tool_calls = [ + OpenAIChatCompletionToolCall( + # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative. 
+ id=f"call_{uuid.uuid4().hex[:24]}", + type="function", + function=OpenAIChatCompletionToolCallFunction( + name=str(tc.tool_name), + arguments=tc.arguments, + ), + ) + for tc in decoded_message.tool_calls + ] + + # Determine finish_reason based on whether tool calls are present + finish_reason = "tool_calls" if openai_tool_calls else "stop" + + # Extract content from decoded message + content = "" + if isinstance(decoded_message.content, str): + content = decoded_message.content + elif isinstance(decoded_message.content, list): + for item in decoded_message.content: + if isinstance(item, RawTextItem): + content += item.text + + # Create OpenAI response + # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative. + response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" + created = int(time.time()) + + return OpenAIChatCompletion( + id=response_id, + object="chat.completion", + created=created, + model=params.model, + choices=[ + OpenAIChoice( + index=0, + message=OpenAIAssistantMessageParam( + role="assistant", + content=content, + tool_calls=openai_tool_calls, + ), + finish_reason=finish_reason, + logprobs=None, + ) + ], + usage=OpenAIChatCompletionUsage( + prompt_tokens=0, # TODO: calculate properly + completion_tokens=0, # TODO: calculate properly + total_tokens=0, # TODO: calculate properly + ), + ) + + async def _stream_chat_completion( + self, + generator, + params: OpenAIChatCompletionRequestWithExtraBody, + ) -> AsyncIterator[OpenAIChatCompletionChunk]: + """Stream chat completion chunks as they're generated.""" + from llama_stack.apis.inference import ( + OpenAIChatCompletionChunk, + OpenAIChatCompletionToolCall, + OpenAIChatCompletionToolCallFunction, + OpenAIChoiceDelta, + OpenAIChunkChoice, + ) + from llama_stack.models.llama.datatypes import StopReason + from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message + + response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" + created = int(time.time()) + generated_text = "" + + # Yield chunks as tokens are generated + for result_batch in generator: + for result in result_batch: + if result.ignore_token or result.source != "output": + continue + + generated_text += result.text + + # Yield delta chunk with the new text + chunk = OpenAIChatCompletionChunk( + id=response_id, + object="chat.completion.chunk", + created=created, + model=params.model, + choices=[ + OpenAIChunkChoice( + index=0, + delta=OpenAIChoiceDelta( + role="assistant", + content=result.text, + ), + finish_reason="", + logprobs=None, + ) + ], + ) + yield chunk + + # After generation completes, decode the full message to extract tool calls + decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn) + + # If tool calls are present, yield a final chunk with tool_calls + if decoded_message.tool_calls: + openai_tool_calls = [ + OpenAIChatCompletionToolCall( + # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative. 
+ id=f"call_{uuid.uuid4().hex[:24]}", + type="function", + function=OpenAIChatCompletionToolCallFunction( + name=str(tc.tool_name), + arguments=tc.arguments, + ), + ) + for tc in decoded_message.tool_calls + ] + + # Yield chunk with tool_calls + chunk = OpenAIChatCompletionChunk( + id=response_id, + object="chat.completion.chunk", + created=created, + model=params.model, + choices=[ + OpenAIChunkChoice( + index=0, + delta=OpenAIChoiceDelta( + role="assistant", + tool_calls=openai_tool_calls, + ), + finish_reason="", + logprobs=None, + ) + ], + ) + yield chunk + + finish_reason = "tool_calls" + else: + finish_reason = "stop" + + # Yield final chunk with finish_reason + final_chunk = OpenAIChatCompletionChunk( + id=response_id, + object="chat.completion.chunk", + created=created, + model=params.model, + choices=[ + OpenAIChunkChoice( + index=0, + delta=OpenAIChoiceDelta(), + finish_reason=finish_reason, + logprobs=None, + ) + ], + ) + yield final_chunk diff --git a/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py index 9d0295d65..f50b41f34 100644 --- a/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py +++ b/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py @@ -4,17 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from collections.abc import Callable, Generator -from copy import deepcopy +from collections.abc import Callable from functools import partial from typing import Any from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat -from llama_stack.providers.utils.inference.prompt_adapter import ( - ChatCompletionRequestWithRawContent, - CompletionRequestWithRawContent, -) from .parallel_utils import ModelParallelProcessGroup @@ -23,12 +18,14 @@ class ModelRunner: def __init__(self, llama): self.llama = llama - # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()` def __call__(self, task: Any): - if task[0] == "chat_completion": - return self.llama.chat_completion(task[1]) + task_type = task[0] + if task_type == "chat_completion": + # task[1] is [params, raw_messages] + params, raw_messages = task[1] + return self.llama.chat_completion(params, raw_messages) else: - raise ValueError(f"Unexpected task type {task[0]}") + raise ValueError(f"Unexpected task type {task_type}") def init_model_cb( @@ -78,19 +75,3 @@ class LlamaModelParallelGenerator: def __exit__(self, exc_type, exc_value, exc_traceback): self.group.stop() - - def completion( - self, - request_batch: list[CompletionRequestWithRawContent], - ) -> Generator: - req_obj = deepcopy(request_batch) - gen = self.group.run_inference(("completion", req_obj)) - yield from gen - - def chat_completion( - self, - request_batch: list[ChatCompletionRequestWithRawContent], - ) -> Generator: - req_obj = deepcopy(request_batch) - gen = self.group.run_inference(("chat_completion", req_obj)) - yield from gen diff --git a/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index bb6a1bd03..663e4793b 100644 --- a/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ 
b/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -33,10 +33,6 @@ from torch.distributed.launcher.api import LaunchConfig, elastic_launch from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import GenerationResult -from llama_stack.providers.utils.inference.prompt_adapter import ( - ChatCompletionRequestWithRawContent, - CompletionRequestWithRawContent, -) log = get_logger(name=__name__, category="inference") @@ -69,10 +65,7 @@ class CancelSentinel(BaseModel): class TaskRequest(BaseModel): type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request - task: tuple[ - str, - list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent], - ] + task: tuple[str, list] class TaskResponse(BaseModel): @@ -328,10 +321,7 @@ class ModelParallelProcessGroup: def run_inference( self, - req: tuple[ - str, - list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent], - ], + req: tuple[str, list], ) -> Generator: assert not self.running, "inference already running" diff --git a/src/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/src/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index cb72aa13a..e6dcf3ae7 100644 --- a/src/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/src/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -22,9 +22,6 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, ) -from llama_stack.providers.utils.inference.openai_compat import ( - OpenAIChatCompletionToLlamaStackMixin, -) from .config import SentenceTransformersInferenceConfig @@ -32,7 +29,6 @@ log = get_logger(name=__name__, category="inference") class SentenceTransformersInferenceImpl( - OpenAIChatCompletionToLlamaStackMixin, SentenceTransformerEmbeddingMixin, InferenceProvider, ModelsProtocolPrivate, diff --git a/src/llama_stack/providers/registry/inference.py b/src/llama_stack/providers/registry/inference.py index 1b70182fc..3cbfd408b 100644 --- a/src/llama_stack/providers/registry/inference.py +++ b/src/llama_stack/providers/registry/inference.py @@ -297,6 +297,20 @@ Available Models: Azure OpenAI inference provider for accessing GPT models and other Azure services. Provider documentation https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview +""", + ), + RemoteProviderSpec( + api=Api.inference, + provider_type="remote::oci", + adapter_type="oci", + pip_packages=["oci"], + module="llama_stack.providers.remote.inference.oci", + config_class="llama_stack.providers.remote.inference.oci.config.OCIConfig", + provider_data_validator="llama_stack.providers.remote.inference.oci.config.OCIProviderDataValidator", + description=""" +Oracle Cloud Infrastructure (OCI) Generative AI inference provider for accessing OCI's Generative AI Platform-as-a-Service models. +Provider documentation +https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm """, ), ] diff --git a/src/llama_stack/providers/remote/inference/oci/__init__.py b/src/llama_stack/providers/remote/inference/oci/__init__.py new file mode 100644 index 000000000..280a8c1d2 --- /dev/null +++ b/src/llama_stack/providers/remote/inference/oci/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.inference import InferenceProvider + +from .config import OCIConfig + + +async def get_adapter_impl(config: OCIConfig, _deps) -> InferenceProvider: + from .oci import OCIInferenceAdapter + + adapter = OCIInferenceAdapter(config=config) + await adapter.initialize() + return adapter diff --git a/src/llama_stack/providers/remote/inference/oci/auth.py b/src/llama_stack/providers/remote/inference/oci/auth.py new file mode 100644 index 000000000..f64436eb5 --- /dev/null +++ b/src/llama_stack/providers/remote/inference/oci/auth.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from collections.abc import Generator, Mapping +from typing import Any, override + +import httpx +import oci +import requests +from oci.config import DEFAULT_LOCATION, DEFAULT_PROFILE + +OciAuthSigner = type[oci.signer.AbstractBaseSigner] + + +class HttpxOciAuth(httpx.Auth): + """ + Custom HTTPX authentication class that implements OCI request signing. + + This class handles the authentication flow for HTTPX requests by signing them + using the OCI Signer, which adds the necessary authentication headers for + OCI API calls. + + Attributes: + signer (oci.signer.Signer): The OCI signer instance used for request signing + """ + + def __init__(self, signer: OciAuthSigner): + self.signer = signer + + @override + def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]: + # Read the request content to handle streaming requests properly + try: + content = request.content + except httpx.RequestNotRead: + # For streaming requests, we need to read the content first + content = request.read() + + req = requests.Request( + method=request.method, + url=str(request.url), + headers=dict(request.headers), + data=content, + ) + prepared_request = req.prepare() + + # Sign the request using the OCI Signer + self.signer.do_request_sign(prepared_request) # type: ignore + + # Update the original HTTPX request with the signed headers + request.headers.update(prepared_request.headers) + + yield request + + +class OciInstancePrincipalAuth(HttpxOciAuth): + def __init__(self, **kwargs: Mapping[str, Any]): + self.signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner(**kwargs) + + +class OciUserPrincipalAuth(HttpxOciAuth): + def __init__(self, config_file: str = DEFAULT_LOCATION, profile_name: str = DEFAULT_PROFILE): + config = oci.config.from_file(config_file, profile_name) + oci.config.validate_config(config) # type: ignore + key_content = "" + with open(config["key_file"]) as f: + key_content = f.read() + + self.signer = oci.signer.Signer( + tenancy=config["tenancy"], + user=config["user"], + fingerprint=config["fingerprint"], + private_key_file_location=config.get("key_file"), + pass_phrase="none", # type: ignore + private_key_content=key_content, + ) diff --git a/src/llama_stack/providers/remote/inference/oci/config.py b/src/llama_stack/providers/remote/inference/oci/config.py new file mode 100644 index 000000000..9747b08ea --- /dev/null +++ b/src/llama_stack/providers/remote/inference/oci/config.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +from typing import Any + +from pydantic import BaseModel, Field + +from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig +from llama_stack.schema_utils import json_schema_type + + +class OCIProviderDataValidator(BaseModel): + oci_auth_type: str = Field( + description="OCI authentication type (must be one of: instance_principal, config_file)", + ) + oci_region: str = Field( + description="OCI region (e.g., us-ashburn-1)", + ) + oci_compartment_id: str = Field( + description="OCI compartment ID for the Generative AI service", + ) + oci_config_file_path: str | None = Field( + default="~/.oci/config", + description="OCI config file path (required if oci_auth_type is config_file)", + ) + oci_config_profile: str | None = Field( + default="DEFAULT", + description="OCI config profile (required if oci_auth_type is config_file)", + ) + + +@json_schema_type +class OCIConfig(RemoteInferenceProviderConfig): + oci_auth_type: str = Field( + description="OCI authentication type (must be one of: instance_principal, config_file)", + default_factory=lambda: os.getenv("OCI_AUTH_TYPE", "instance_principal"), + ) + oci_region: str = Field( + default_factory=lambda: os.getenv("OCI_REGION", "us-ashburn-1"), + description="OCI region (e.g., us-ashburn-1)", + ) + oci_compartment_id: str = Field( + default_factory=lambda: os.getenv("OCI_COMPARTMENT_OCID", ""), + description="OCI compartment ID for the Generative AI service", + ) + oci_config_file_path: str = Field( + default_factory=lambda: os.getenv("OCI_CONFIG_FILE_PATH", "~/.oci/config"), + description="OCI config file path (required if oci_auth_type is config_file)", + ) + oci_config_profile: str = Field( + default_factory=lambda: os.getenv("OCI_CLI_PROFILE", "DEFAULT"), + description="OCI config profile (required if oci_auth_type is config_file)", + ) + + @classmethod + def sample_run_config( + cls, + oci_auth_type: str = "${env.OCI_AUTH_TYPE:=instance_principal}", + oci_config_file_path: str = "${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}", + oci_config_profile: str = "${env.OCI_CLI_PROFILE:=DEFAULT}", + oci_region: str = "${env.OCI_REGION:=us-ashburn-1}", + oci_compartment_id: str = "${env.OCI_COMPARTMENT_OCID:=}", + **kwargs, + ) -> dict[str, Any]: + return { + "oci_auth_type": oci_auth_type, + "oci_config_file_path": oci_config_file_path, + "oci_config_profile": oci_config_profile, + "oci_region": oci_region, + "oci_compartment_id": oci_compartment_id, + } diff --git a/src/llama_stack/providers/remote/inference/oci/oci.py b/src/llama_stack/providers/remote/inference/oci/oci.py new file mode 100644 index 000000000..253dcf2b6 --- /dev/null +++ b/src/llama_stack/providers/remote/inference/oci/oci.py @@ -0,0 +1,140 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
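+
+# The adapter below reuses OpenAIMixin for chat completions: get_base_url()
+# targets the region's Generative AI endpoint, get_api_key() returns an empty
+# string because OCI does not use API keys, and every request is signed by the
+# httpx auth classes in auth.py (instance principal or ~/.oci/config profile).
+#
+# Illustrative run-config snippet (field names mirror OCIConfig; values are
+# only the sample defaults, not requirements):
+#   provider_type: remote::oci
+#   config:
+#     oci_auth_type: ${env.OCI_AUTH_TYPE:=instance_principal}
+#     oci_region: ${env.OCI_REGION:=us-ashburn-1}
+#     oci_compartment_id: ${env.OCI_COMPARTMENT_OCID:=}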
+ + +from collections.abc import Iterable +from typing import Any + +import httpx +import oci +from oci.generative_ai.generative_ai_client import GenerativeAiClient +from oci.generative_ai.models import ModelCollection +from openai._base_client import DefaultAsyncHttpxClient + +from llama_stack.apis.inference.inference import ( + OpenAIEmbeddingsRequestWithExtraBody, + OpenAIEmbeddingsResponse, +) +from llama_stack.apis.models import ModelType +from llama_stack.log import get_logger +from llama_stack.providers.remote.inference.oci.auth import OciInstancePrincipalAuth, OciUserPrincipalAuth +from llama_stack.providers.remote.inference.oci.config import OCIConfig +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin + +logger = get_logger(name=__name__, category="inference::oci") + +OCI_AUTH_TYPE_INSTANCE_PRINCIPAL = "instance_principal" +OCI_AUTH_TYPE_CONFIG_FILE = "config_file" +VALID_OCI_AUTH_TYPES = [OCI_AUTH_TYPE_INSTANCE_PRINCIPAL, OCI_AUTH_TYPE_CONFIG_FILE] +DEFAULT_OCI_REGION = "us-ashburn-1" + +MODEL_CAPABILITIES = ["TEXT_GENERATION", "TEXT_SUMMARIZATION", "TEXT_EMBEDDINGS", "CHAT"] + + +class OCIInferenceAdapter(OpenAIMixin): + config: OCIConfig + + async def initialize(self) -> None: + """Initialize and validate OCI configuration.""" + if self.config.oci_auth_type not in VALID_OCI_AUTH_TYPES: + raise ValueError( + f"Invalid OCI authentication type: {self.config.oci_auth_type}." + f"Valid types are one of: {VALID_OCI_AUTH_TYPES}" + ) + + if not self.config.oci_compartment_id: + raise ValueError("OCI_COMPARTMENT_OCID is a required parameter. Either set in env variable or config.") + + def get_base_url(self) -> str: + region = self.config.oci_region or DEFAULT_OCI_REGION + return f"https://inference.generativeai.{region}.oci.oraclecloud.com/20231130/actions/v1" + + def get_api_key(self) -> str | None: + # OCI doesn't use API keys, it uses request signing + return "" + + def get_extra_client_params(self) -> dict[str, Any]: + """ + Get extra parameters for the AsyncOpenAI client, including OCI-specific auth and headers. + """ + auth = self._get_auth() + compartment_id = self.config.oci_compartment_id or "" + + return { + "http_client": DefaultAsyncHttpxClient( + auth=auth, + headers={ + "CompartmentId": compartment_id, + }, + ), + } + + def _get_oci_signer(self) -> oci.signer.AbstractBaseSigner | None: + if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL: + return oci.auth.signers.InstancePrincipalsSecurityTokenSigner() + return None + + def _get_oci_config(self) -> dict: + if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL: + config = {"region": self.config.oci_region} + elif self.config.oci_auth_type == OCI_AUTH_TYPE_CONFIG_FILE: + config = oci.config.from_file(self.config.oci_config_file_path, self.config.oci_config_profile) + if not config.get("region"): + raise ValueError( + "Region not specified in config. Please specify in config or with OCI_REGION env variable." + ) + + return config + + def _get_auth(self) -> httpx.Auth: + if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL: + return OciInstancePrincipalAuth() + elif self.config.oci_auth_type == OCI_AUTH_TYPE_CONFIG_FILE: + return OciUserPrincipalAuth( + config_file=self.config.oci_config_file_path, profile_name=self.config.oci_config_profile + ) + else: + raise ValueError(f"Invalid OCI authentication type: {self.config.oci_auth_type}") + + async def list_provider_model_ids(self) -> Iterable[str]: + """ + List available models from OCI Generative AI service. 
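+
+        Only active, CHAT-capable base models are returned: deprecated or retired
+        models and fine-tune-only models are skipped, and entries are de-duplicated
+        by display name.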
+ """ + oci_config = self._get_oci_config() + oci_signer = self._get_oci_signer() + compartment_id = self.config.oci_compartment_id or "" + + if oci_signer is None: + client = GenerativeAiClient(config=oci_config) + else: + client = GenerativeAiClient(config=oci_config, signer=oci_signer) + + models: ModelCollection = client.list_models( + compartment_id=compartment_id, capability=MODEL_CAPABILITIES, lifecycle_state="ACTIVE" + ).data + + seen_models = set() + model_ids = [] + for model in models.items: + if model.time_deprecated or model.time_on_demand_retired: + continue + + if "CHAT" not in model.capabilities or "FINE_TUNE" in model.capabilities: + continue + + # Use display_name + model_type as the key to avoid conflicts + model_key = (model.display_name, ModelType.llm) + if model_key in seen_models: + continue + + seen_models.add(model_key) + model_ids.append(model.display_name) + + return model_ids + + async def openai_embeddings(self, params: OpenAIEmbeddingsRequestWithExtraBody) -> OpenAIEmbeddingsResponse: + # The constructed url is a mask that hits OCI's "chat" action, which is not supported for embeddings. + raise NotImplementedError("OCI Provider does not (currently) support embeddings") diff --git a/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 223497fb8..a793c499e 100644 --- a/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -11,9 +11,7 @@ from collections.abc import AsyncIterator import litellm from llama_stack.apis.inference import ( - ChatCompletionRequest, InferenceProvider, - JsonSchemaResponseFormat, OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAIChatCompletionRequestWithExtraBody, @@ -23,15 +21,11 @@ from llama_stack.apis.inference import ( OpenAIEmbeddingsRequestWithExtraBody, OpenAIEmbeddingsResponse, OpenAIEmbeddingUsage, - ToolChoice, ) from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry from llama_stack.providers.utils.inference.openai_compat import ( - convert_message_to_openai_dict_new, - convert_tooldef_to_openai_tool, - get_sampling_options, prepare_openai_completion_params, ) @@ -127,51 +121,6 @@ class LiteLLMOpenAIMixin( return schema - async def _get_params(self, request: ChatCompletionRequest) -> dict: - from typing import Any - - input_dict: dict[str, Any] = {} - - input_dict["messages"] = [ - await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages - ] - if fmt := request.response_format: - if not isinstance(fmt, JsonSchemaResponseFormat): - raise ValueError( - f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported." 
- ) - - # Convert to dict for manipulation - fmt_dict = dict(fmt.json_schema) - name = fmt_dict["title"] - del fmt_dict["title"] - fmt_dict["additionalProperties"] = False - - # Apply additionalProperties: False recursively to all objects - fmt_dict = self._add_additional_properties_recursive(fmt_dict) - - input_dict["response_format"] = { - "type": "json_schema", - "json_schema": { - "name": name, - "schema": fmt_dict, - "strict": self.json_schema_strict, - }, - } - if request.tools: - input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools] - if request.tool_config and (tool_choice := request.tool_config.tool_choice): - input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice - - return { - "model": request.model, - "api_key": self.get_api_key(), - "api_base": self.api_base, - **input_dict, - "stream": request.stream, - **get_sampling_options(request.sampling_params), - } - def get_api_key(self) -> str: provider_data = self.get_request_provider_data() key_field = self.provider_data_api_key_field diff --git a/src/llama_stack/providers/utils/inference/openai_compat.py b/src/llama_stack/providers/utils/inference/openai_compat.py index aabcb50f8..c2e6829e0 100644 --- a/src/llama_stack/providers/utils/inference/openai_compat.py +++ b/src/llama_stack/providers/utils/inference/openai_compat.py @@ -3,31 +3,14 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json -import time -import uuid -import warnings -from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Iterable +from collections.abc import Iterable from typing import ( Any, ) -from openai import AsyncStream -from openai.types.chat import ( - ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, -) -from openai.types.chat import ( - ChatCompletionChunk as OpenAIChatCompletionChunk, -) -from openai.types.chat import ( - ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, -) from openai.types.chat import ( ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, ) -from openai.types.chat import ( - ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, -) try: from openai.types.chat import ( @@ -37,84 +20,24 @@ except ImportError: from openai.types.chat.chat_completion_message_tool_call import ( ChatCompletionMessageToolCall as OpenAIChatCompletionMessageFunctionToolCall, ) -from openai.types.chat import ( - ChatCompletionMessageParam as OpenAIChatCompletionMessage, -) from openai.types.chat import ( ChatCompletionMessageToolCall, ) -from openai.types.chat import ( - ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, -) -from openai.types.chat import ( - ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, -) -from openai.types.chat import ( - ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, -) -from openai.types.chat.chat_completion import ( - Choice as OpenAIChoice, -) -from openai.types.chat.chat_completion import ( - ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs -) -from openai.types.chat.chat_completion_chunk import ( - Choice as OpenAIChatCompletionChunkChoice, -) -from openai.types.chat.chat_completion_chunk import ( - ChoiceDelta as OpenAIChoiceDelta, -) -from openai.types.chat.chat_completion_chunk import ( - ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall, -) -from 
openai.types.chat.chat_completion_chunk import ( - ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction, -) -from openai.types.chat.chat_completion_content_part_image_param import ( - ImageURL as OpenAIImageURL, -) -from openai.types.chat.chat_completion_message_tool_call import ( - Function as OpenAIFunction, -) from pydantic import BaseModel from llama_stack.apis.common.content_types import ( URL, ImageContentItem, - InterleavedContent, TextContentItem, - TextDelta, - ToolCallDelta, - ToolCallParseStatus, _URLOrData, ) from llama_stack.apis.inference import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseEvent, - ChatCompletionResponseEventType, - ChatCompletionResponseStreamChunk, - CompletionMessage, - CompletionResponse, - CompletionResponseStreamChunk, GreedySamplingStrategy, JsonSchemaResponseFormat, - Message, - OpenAIChatCompletion, - OpenAIMessageParam, OpenAIResponseFormatParam, SamplingParams, - SystemMessage, - TokenLogProbs, - ToolChoice, - ToolConfig, - ToolResponseMessage, TopKSamplingStrategy, TopPSamplingStrategy, - UserMessage, -) -from llama_stack.apis.inference import ( - OpenAIChoice as OpenAIChatCompletionChoice, ) from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( @@ -123,10 +46,6 @@ from llama_stack.models.llama.datatypes import ( ToolCall, ToolDefinition, ) -from llama_stack.providers.utils.inference.prompt_adapter import ( - convert_image_content_to_url, - decode_assistant_message, -) logger = get_logger(name=__name__, category="providers::utils") @@ -213,345 +132,6 @@ def get_stop_reason(finish_reason: str) -> StopReason: return StopReason.out_of_tokens -def convert_openai_completion_logprobs( - logprobs: OpenAICompatLogprobs | None, -) -> list[TokenLogProbs] | None: - if not logprobs: - return None - if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs: - return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs] - - # Together supports logprobs with top_k=1 only. This means for each token position, - # they return only the logprobs for the selected token (vs. the top n most likely tokens). - # Here we construct the response by matching the selected token with the logprobs. 
- if logprobs.tokens and logprobs.token_logprobs: - return [ - TokenLogProbs(logprobs_by_token={token: token_lp}) - for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False) - ] - return None - - -def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenAICompatLogprobs | None): - if logprobs is None: - return None - if isinstance(logprobs, float): - # Adapt response from Together CompletionChoicesChunk - return [TokenLogProbs(logprobs_by_token={text: logprobs})] - if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs: - return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs] - return None - - -def process_completion_response( - response: OpenAICompatCompletionResponse, -) -> CompletionResponse: - choice = response.choices[0] - text = choice.text or "" - # drop suffix if present and return stop reason as end of turn - if text.endswith("<|eot_id|>"): - return CompletionResponse( - stop_reason=StopReason.end_of_turn, - content=text[: -len("<|eot_id|>")], - logprobs=convert_openai_completion_logprobs(choice.logprobs), - ) - # drop suffix if present and return stop reason as end of message - if text.endswith("<|eom_id|>"): - return CompletionResponse( - stop_reason=StopReason.end_of_message, - content=text[: -len("<|eom_id|>")], - logprobs=convert_openai_completion_logprobs(choice.logprobs), - ) - return CompletionResponse( - stop_reason=get_stop_reason(choice.finish_reason or "stop"), - content=text, - logprobs=convert_openai_completion_logprobs(choice.logprobs), - ) - - -def process_chat_completion_response( - response: OpenAICompatCompletionResponse, - request: ChatCompletionRequest, -) -> ChatCompletionResponse: - choice = response.choices[0] - if choice.finish_reason == "tool_calls": - if not hasattr(choice, "message") or not choice.message or not choice.message.tool_calls: # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed - raise ValueError("Tool calls are not present in the response") - - tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed - if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls): - # If we couldn't parse a tool call, jsonify the tool calls and return them - return ChatCompletionResponse( - completion_message=CompletionMessage( - stop_reason=StopReason.end_of_turn, - content=json.dumps(tool_calls, default=lambda x: x.model_dump()), - ), - logprobs=None, - ) - else: - # Otherwise, return tool calls as normal - # Filter to only valid ToolCall objects - valid_tool_calls = [tc for tc in tool_calls if isinstance(tc, ToolCall)] - return ChatCompletionResponse( - completion_message=CompletionMessage( - tool_calls=valid_tool_calls, - stop_reason=StopReason.end_of_turn, - # Content is not optional - content="", - ), - logprobs=None, - ) - - # TODO: This does not work well with tool calls for vLLM remote provider - # Ref: https://github.com/meta-llama/llama-stack/issues/1058 - raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason or "stop")) - - # NOTE: If we do not set tools in chat-completion request, we should not - # expect the ToolCall in the response. Instead, we should return the raw - # response from the model. 
- if raw_message.tool_calls: - if not request.tools: - raw_message.tool_calls = [] - raw_message.content = text_from_choice(choice) - else: - # only return tool_calls if provided in the request - new_tool_calls = [] - request_tools = {t.tool_name: t for t in request.tools} - for t in raw_message.tool_calls: - if t.tool_name in request_tools: - new_tool_calls.append(t) - else: - logger.warning(f"Tool {t.tool_name} not found in request tools") - - if len(new_tool_calls) < len(raw_message.tool_calls): - raw_message.tool_calls = new_tool_calls - raw_message.content = text_from_choice(choice) - - return ChatCompletionResponse( - completion_message=CompletionMessage( - content=raw_message.content, # type: ignore[arg-type] # decode_assistant_message returns Union[str, InterleavedContent] - stop_reason=raw_message.stop_reason or StopReason.end_of_turn, - tool_calls=raw_message.tool_calls, - ), - logprobs=None, - ) - - -async def process_completion_stream_response( - stream: AsyncGenerator[OpenAICompatCompletionResponse, None], -) -> AsyncGenerator[CompletionResponseStreamChunk, None]: - stop_reason = None - - async for chunk in stream: - choice = chunk.choices[0] - finish_reason = choice.finish_reason - - text = text_from_choice(choice) - if text == "<|eot_id|>": - stop_reason = StopReason.end_of_turn - text = "" - continue - elif text == "<|eom_id|>": - stop_reason = StopReason.end_of_message - text = "" - continue - yield CompletionResponseStreamChunk( - delta=text, - stop_reason=stop_reason, - logprobs=convert_openai_completion_logprobs_stream(text, choice.logprobs), - ) - if finish_reason: - if finish_reason in ["stop", "eos", "eos_token"]: - stop_reason = StopReason.end_of_turn - elif finish_reason == "length": - stop_reason = StopReason.out_of_tokens - break - - yield CompletionResponseStreamChunk( - delta="", - stop_reason=stop_reason, - ) - - -async def process_chat_completion_stream_response( - stream: AsyncGenerator[OpenAICompatCompletionResponse, None], - request: ChatCompletionRequest, -) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta=TextDelta(text=""), - ) - ) - - buffer = "" - ipython = False - stop_reason = None - - async for chunk in stream: - choice = chunk.choices[0] - finish_reason = choice.finish_reason - - if finish_reason: - if stop_reason is None and finish_reason in ["stop", "eos", "eos_token"]: - stop_reason = StopReason.end_of_turn - elif stop_reason is None and finish_reason == "length": - stop_reason = StopReason.out_of_tokens - break - - text = text_from_choice(choice) - if not text: - # Sometimes you get empty chunks from providers - continue - - # check if its a tool call ( aka starts with <|python_tag|> ) - if not ipython and text.startswith("<|python_tag|>"): - ipython = True - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call="", - parse_status=ToolCallParseStatus.started, - ), - ) - ) - buffer += text - continue - - if text == "<|eot_id|>": - stop_reason = StopReason.end_of_turn - text = "" - continue - elif text == "<|eom_id|>": - stop_reason = StopReason.end_of_message - text = "" - continue - - if ipython: - buffer += text - delta = ToolCallDelta( - tool_call=text, - parse_status=ToolCallParseStatus.in_progress, - ) - - yield ChatCompletionResponseStreamChunk( - 
event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=delta, - stop_reason=stop_reason, - ) - ) - else: - buffer += text - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=TextDelta(text=text), - stop_reason=stop_reason, - ) - ) - - # parse tool calls and report errors - message = decode_assistant_message(buffer, stop_reason or StopReason.end_of_turn) - - parsed_tool_calls = len(message.tool_calls) > 0 - if ipython and not parsed_tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call="", - parse_status=ToolCallParseStatus.failed, - ), - stop_reason=stop_reason, - ) - ) - - request_tools = {t.tool_name: t for t in (request.tools or [])} - for tool_call in message.tool_calls: - if tool_call.tool_name in request_tools: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call=tool_call, - parse_status=ToolCallParseStatus.succeeded, - ), - stop_reason=stop_reason, - ) - ) - else: - logger.warning(f"Tool {tool_call.tool_name} not found in request tools") - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - # Parsing tool call failed due to tool call not being found in request tools, - # We still add the raw message text inside tool_call for responding back to the user - tool_call=buffer, - parse_status=ToolCallParseStatus.failed, - ), - stop_reason=stop_reason, - ) - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta=TextDelta(text=""), - stop_reason=stop_reason, - ) - ) - - -async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict: - async def _convert_content(content) -> dict: - if isinstance(content, ImageContentItem): - return { - "type": "image_url", - "image_url": { - "url": await convert_image_content_to_url(content, download=download), - }, - } - else: - text = content.text if isinstance(content, TextContentItem) else content - assert isinstance(text, str) - return {"type": "text", "text": text} - - if isinstance(message.content, list): - content = [await _convert_content(c) for c in message.content] - else: - content = [await _convert_content(message.content)] - - result = { - "role": message.role, - "content": content, - } - - if hasattr(message, "tool_calls") and message.tool_calls: - tool_calls_list = [] - for tc in message.tool_calls: - # The tool.tool_name can be a str or a BuiltinTool enum. If - # it's the latter, convert to a string. - tool_name = tc.tool_name - if isinstance(tool_name, BuiltinTool): - tool_name = tool_name.value - - tool_calls_list.append( - { - "id": tc.call_id, - "type": "function", - "function": { - "name": tool_name, - "arguments": tc.arguments, - }, - } - ) - result["tool_calls"] = tool_calls_list # type: ignore[assignment] # dict allows Any value, stricter type expected - return result - - class UnparseableToolCall(BaseModel): """ A ToolCall with arguments that are not valid JSON. 
@@ -563,112 +143,6 @@ class UnparseableToolCall(BaseModel): arguments: str = "" -async def convert_message_to_openai_dict_new( - message: Message | dict, - download_images: bool = False, -) -> OpenAIChatCompletionMessage: - """ - Convert a Message to an OpenAI API-compatible dictionary. - """ - # users can supply a dict instead of a Message object, we'll - # convert it to a Message object and proceed with some type safety. - if isinstance(message, dict): - if "role" not in message: - raise ValueError("role is required in message") - if message["role"] == "user": - message = UserMessage(**message) - elif message["role"] == "assistant": - message = CompletionMessage(**message) - elif message["role"] == "tool": - message = ToolResponseMessage(**message) - elif message["role"] == "system": - message = SystemMessage(**message) - else: - raise ValueError(f"Unsupported message role: {message['role']}") - - # Map Llama Stack spec to OpenAI spec - - # str -> str - # {"type": "text", "text": ...} -> {"type": "text", "text": ...} - # {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}} - # {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}} - # List[...] -> List[...] - async def _convert_message_content( - content: InterleavedContent, - ) -> str | Iterable[OpenAIChatCompletionContentPartParam]: - async def impl( - content_: InterleavedContent, - ) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]: - # Llama Stack and OpenAI spec match for str and text input - if isinstance(content_, str): - return content_ - elif isinstance(content_, TextContentItem): - return OpenAIChatCompletionContentPartTextParam( - type="text", - text=content_.text, - ) - elif isinstance(content_, ImageContentItem): - return OpenAIChatCompletionContentPartImageParam( - type="image_url", - image_url=OpenAIImageURL( - url=await convert_image_content_to_url(content_, download=download_images) - ), - ) - elif isinstance(content_, list): - return [await impl(item) for item in content_] # type: ignore[misc] # recursive list comprehension confuses mypy's type narrowing - else: - raise ValueError(f"Unsupported content type: {type(content_)}") - - ret = await impl(content) - - # OpenAI*Message expects a str or list - if isinstance(ret, str) or isinstance(ret, list): - return ret - else: - return [ret] - - out: OpenAIChatCompletionMessage - if isinstance(message, UserMessage): - out = OpenAIChatCompletionUserMessage( - role="user", - content=await _convert_message_content(message.content), - ) - elif isinstance(message, CompletionMessage): - tool_calls = [ - OpenAIChatCompletionMessageFunctionToolCall( - id=tool.call_id, - function=OpenAIFunction( - name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), - arguments=tool.arguments, # Already a JSON string, don't double-encode - ), - type="function", - ) - for tool in (message.tool_calls or []) - ] - params = {} - if tool_calls: - params["tool_calls"] = tool_calls - out = OpenAIChatCompletionAssistantMessage( - role="assistant", - content=await _convert_message_content(message.content), - **params, # type: ignore[typeddict-item] # tool_calls dict expansion conflicts with TypedDict optional field - ) - elif isinstance(message, ToolResponseMessage): - out = OpenAIChatCompletionToolMessage( - role="tool", - tool_call_id=message.call_id, - content=await _convert_message_content(message.content), # type: 
ignore[typeddict-item] # content union type incompatible with TypedDict str requirement - ) - elif isinstance(message, SystemMessage): - out = OpenAIChatCompletionSystemMessage( - role="system", - content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement - ) - else: - raise ValueError(f"Unsupported message type: {type(message)}") - - return out - - def convert_tool_call( tool_call: ChatCompletionMessageToolCall, ) -> ToolCall | UnparseableToolCall: @@ -817,17 +291,6 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason: }.get(finish_reason, StopReason.end_of_turn) -def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None = None) -> ToolConfig: - tool_config = ToolConfig() - if tool_choice: - try: - tool_choice = ToolChoice(tool_choice) # type: ignore[assignment] # reassigning to enum narrows union but mypy can't track after exception - except ValueError: - pass - tool_config.tool_choice = tool_choice # type: ignore[assignment] # ToolConfig.tool_choice accepts Union[ToolChoice, dict] but mypy tracks narrower type - return tool_config - - def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]: lls_tools: list[ToolDefinition] = [] if not tools: @@ -898,40 +361,6 @@ def _convert_openai_tool_calls( ] -def _convert_openai_logprobs( - logprobs: OpenAIChoiceLogprobs, -) -> list[TokenLogProbs] | None: - """ - Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs. - - OpenAI ChoiceLogprobs: - content: Optional[List[ChatCompletionTokenLogprob]] - - OpenAI ChatCompletionTokenLogprob: - token: str - logprob: float - top_logprobs: List[TopLogprob] - - OpenAI TopLogprob: - token: str - logprob: float - - -> - - TokenLogProbs: - logprobs_by_token: Dict[str, float] - - token, logprob - - """ - if not logprobs or not logprobs.content: - return None - - return [ - TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs}) - for content in logprobs.content - ] - - def _convert_openai_sampling_params( max_tokens: int | None = None, temperature: float | None = None, @@ -956,37 +385,6 @@ def _convert_openai_sampling_params( return sampling_params -def openai_messages_to_messages( - messages: list[OpenAIMessageParam], -) -> list[Message]: - """ - Convert a list of OpenAIChatCompletionMessage into a list of Message. 
- """ - converted_messages: list[Message] = [] - for message in messages: - converted_message: Message - if message.role == "system": - converted_message = SystemMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types - elif message.role == "user": - converted_message = UserMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types - elif message.role == "assistant": - converted_message = CompletionMessage( - content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types - tool_calls=_convert_openai_tool_calls(message.tool_calls) if message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls type incompatible with conversion function - stop_reason=StopReason.end_of_turn, - ) - elif message.role == "tool": - converted_message = ToolResponseMessage( - role="tool", - call_id=message.tool_call_id, - content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types - ) - else: - raise ValueError(f"Unknown role {message.role}") - converted_messages.append(converted_message) - return converted_messages - - def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionContentPartParam] | None): if content is None: return "" @@ -1005,216 +403,6 @@ def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionConten raise ValueError(f"Unknown content type: {content}") -def convert_openai_chat_completion_choice( - choice: OpenAIChoice, -) -> ChatCompletionResponse: - """ - Convert an OpenAI Choice into a ChatCompletionResponse. 
- - OpenAI Choice: - message: ChatCompletionMessage - finish_reason: str - logprobs: Optional[ChoiceLogprobs] - - OpenAI ChatCompletionMessage: - role: Literal["assistant"] - content: Optional[str] - tool_calls: Optional[List[ChatCompletionMessageToolCall]] - - -> - - ChatCompletionResponse: - completion_message: CompletionMessage - logprobs: Optional[List[TokenLogProbs]] - - CompletionMessage: - role: Literal["assistant"] - content: str | ImageMedia | List[str | ImageMedia] - stop_reason: StopReason - tool_calls: List[ToolCall] - - class StopReason(Enum): - end_of_turn = "end_of_turn" - end_of_message = "end_of_message" - out_of_tokens = "out_of_tokens" - """ - assert hasattr(choice, "message") and choice.message, "error in server response: message not found" - assert hasattr(choice, "finish_reason") and choice.finish_reason, ( - "error in server response: finish_reason not found" - ) - - return ChatCompletionResponse( - completion_message=CompletionMessage( - content=choice.message.content or "", # CompletionMessage content is not optional - stop_reason=_convert_openai_finish_reason(choice.finish_reason), - tool_calls=_convert_openai_tool_calls(choice.message.tool_calls) if choice.message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls Optional type broadens union - ), - logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)), # type: ignore[arg-type] # getattr returns Any, can't narrow without inspection - ) - - -async def convert_openai_chat_completion_stream( - stream: AsyncStream[OpenAIChatCompletionChunk], - enable_incremental_tool_calls: bool, -) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: - """ - Convert a stream of OpenAI chat completion chunks into a stream - of ChatCompletionResponseStreamChunk. - """ - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta=TextDelta(text=""), - ) - ) - event_type = ChatCompletionResponseEventType.progress - - stop_reason = None - tool_call_idx_to_buffer = {} - - async for chunk in stream: - choice = chunk.choices[0] # assuming only one choice per chunk - - # we assume there's only one finish_reason in the stream - stop_reason = _convert_openai_finish_reason(choice.finish_reason) if choice.finish_reason else stop_reason - logprobs = getattr(choice, "logprobs", None) - - # if there's a tool call, emit an event for each tool in the list - # if tool call and content, emit both separately - if choice.delta.tool_calls: - # the call may have content and a tool call. 
ChatCompletionResponseEvent - # does not support both, so we emit the content first - if choice.delta.content: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=TextDelta(text=choice.delta.content), - logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result - ) - ) - - # it is possible to have parallel tool calls in stream, but - # ChatCompletionResponseEvent only supports one per stream - if len(choice.delta.tool_calls) > 1: - warnings.warn( - "multiple tool calls found in a single delta, using the first, ignoring the rest", - stacklevel=2, - ) - - if not enable_incremental_tool_calls: - for tool_call in choice.delta.tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=ToolCallDelta( - tool_call=_convert_openai_tool_calls([tool_call])[0], # type: ignore[arg-type, list-item] # delta tool_call type differs from complete tool_call - parse_status=ToolCallParseStatus.succeeded, - ), - logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result - ) - ) - else: - for tool_call in choice.delta.tool_calls: - idx = tool_call.index if hasattr(tool_call, "index") else 0 - - if idx not in tool_call_idx_to_buffer: - tool_call_idx_to_buffer[idx] = { - "call_id": tool_call.id, - "name": None, - "arguments": "", - "content": "", - } - - buffer = tool_call_idx_to_buffer[idx] - - if tool_call.function: - if tool_call.function.name: - buffer["name"] = tool_call.function.name - delta = f"{buffer['name']}(" - if buffer["content"] is not None: - buffer["content"] += delta - - if tool_call.function.arguments: - delta = tool_call.function.arguments - if buffer["arguments"] is not None and delta: - buffer["arguments"] += delta - if buffer["content"] is not None and delta: - buffer["content"] += delta - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=ToolCallDelta( - tool_call=delta, - parse_status=ToolCallParseStatus.in_progress, - ), - logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result - ) - ) - elif choice.delta.content: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=TextDelta(text=choice.delta.content or ""), - logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result - ) - ) - - for idx, buffer in tool_call_idx_to_buffer.items(): - logger.debug(f"toolcall_buffer[{idx}]: {buffer}") - if buffer["name"]: - delta = ")" - if buffer["content"] is not None: - buffer["content"] += delta - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=ToolCallDelta( - tool_call=delta, - parse_status=ToolCallParseStatus.in_progress, - ), - logprobs=None, - ) - ) - - try: - parsed_tool_call = ToolCall( - call_id=buffer["call_id"] or "", - tool_name=buffer["name"] or "", - arguments=buffer["arguments"] or "", - ) - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call=parsed_tool_call, # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall] - parse_status=ToolCallParseStatus.succeeded, - ), - stop_reason=stop_reason, - ) - ) - except 
json.JSONDecodeError as e: - print(f"Failed to parse arguments: {e}") - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call=buffer["content"], # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall] - parse_status=ToolCallParseStatus.failed, - ), - stop_reason=stop_reason, - ) - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta=TextDelta(text=""), - stop_reason=stop_reason, - ) - ) - - async def prepare_openai_completion_params(**params): async def _prepare_value(value: Any) -> Any: new_value = value @@ -1233,163 +421,6 @@ async def prepare_openai_completion_params(**params): return completion_params -class OpenAIChatCompletionToLlamaStackMixin: - async def openai_chat_completion( - self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, - ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - messages = openai_messages_to_messages(messages) # type: ignore[assignment] # converted from OpenAI to LlamaStack message format - response_format = _convert_openai_request_response_format(response_format) - sampling_params = _convert_openai_sampling_params( - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - ) - tool_config = _convert_openai_request_tool_config(tool_choice) - - tools = _convert_openai_request_tools(tools) # type: ignore[assignment] # converted from OpenAI to LlamaStack tool format - if tool_config.tool_choice == ToolChoice.none: - tools = [] # type: ignore[assignment] # empty list narrows return type but mypy tracks broader type - - outstanding_responses = [] - # "n" is the number of completions to generate per prompt - n = n or 1 - for _i in range(0, n): - response = self.chat_completion( # type: ignore[attr-defined] # mixin expects class to implement chat_completion - model_id=model, - messages=messages, - sampling_params=sampling_params, - response_format=response_format, - stream=stream, - tool_config=tool_config, - tools=tools, - ) - outstanding_responses.append(response) - - if stream: - return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses) # type: ignore[no-any-return] # mixin async generator return type too complex for mypy - - return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response( - self, model, outstanding_responses - ) - - async def _process_stream_response( - self, - model: str, - outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]], - ): - id = f"chatcmpl-{uuid.uuid4()}" 
- for i, outstanding_response in enumerate(outstanding_responses): - response = await outstanding_response - async for chunk in response: - event = chunk.event - finish_reason = ( - _convert_stop_reason_to_openai_finish_reason(event.stop_reason) if event.stop_reason else None - ) - - if isinstance(event.delta, TextDelta): - text_delta = event.delta.text - delta = OpenAIChoiceDelta(content=text_delta) - yield OpenAIChatCompletionChunk( - id=id, - choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)], # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union - created=int(time.time()), - model=model, - object="chat.completion.chunk", - ) - elif isinstance(event.delta, ToolCallDelta): - if event.delta.parse_status == ToolCallParseStatus.succeeded: - tool_call = event.delta.tool_call - if isinstance(tool_call, str): - continue - - # First chunk includes full structure - openai_tool_call = OpenAIChoiceDeltaToolCall( - index=0, - id=tool_call.call_id, - function=OpenAIChoiceDeltaToolCallFunction( - name=tool_call.tool_name - if isinstance(tool_call.tool_name, str) - else tool_call.tool_name.value, # type: ignore[arg-type] # enum .value extraction on Union confuses mypy - arguments="", - ), - ) - delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call]) - yield OpenAIChatCompletionChunk( - id=id, - choices=[ - OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union - ], - created=int(time.time()), - model=model, - object="chat.completion.chunk", - ) - # arguments - openai_tool_call = OpenAIChoiceDeltaToolCall( - index=0, - function=OpenAIChoiceDeltaToolCallFunction( - arguments=tool_call.arguments, - ), - ) - delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call]) - yield OpenAIChatCompletionChunk( - id=id, - choices=[ - OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union - ], - created=int(time.time()), - model=model, - object="chat.completion.chunk", - ) - - async def _process_non_stream_response( - self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]] - ) -> OpenAIChatCompletion: - choices: list[OpenAIChatCompletionChoice] = [] - for outstanding_response in outstanding_responses: - response = await outstanding_response - completion_message = response.completion_message - message = await convert_message_to_openai_dict_new(completion_message) - finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason) - - choice = OpenAIChatCompletionChoice( - index=len(choices), - message=message, # type: ignore[arg-type] # OpenAIChatCompletionMessage union incompatible with narrower Message type - finish_reason=finish_reason, - ) - choices.append(choice) # type: ignore[arg-type] # OpenAIChatCompletionChoice type annotation mismatch - - return OpenAIChatCompletion( - id=f"chatcmpl-{uuid.uuid4()}", - choices=choices, # type: ignore[arg-type] # list[OpenAIChatCompletionChoice] union incompatible - created=int(time.time()), - model=model, - object="chat.completion", - ) - - def prepare_openai_embeddings_params( model: str, input: str | list[str], diff --git a/src/llama_stack/providers/utils/inference/prompt_adapter.py b/src/llama_stack/providers/utils/inference/prompt_adapter.py index d06b7454d..35a7b3484 100644 --- 
a/src/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/src/llama_stack/providers/utils/inference/prompt_adapter.py @@ -21,19 +21,18 @@ from llama_stack.apis.common.content_types import ( TextContentItem, ) from llama_stack.apis.inference import ( - ChatCompletionRequest, CompletionRequest, - Message, + OpenAIAssistantMessageParam, OpenAIChatCompletionContentPartImageParam, OpenAIChatCompletionContentPartTextParam, OpenAIFile, + OpenAIMessageParam, + OpenAISystemMessageParam, + OpenAIToolMessageParam, + OpenAIUserMessageParam, ResponseFormat, ResponseFormatType, - SystemMessage, - SystemMessageBehavior, ToolChoice, - ToolDefinition, - UserMessage, ) from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( @@ -42,33 +41,19 @@ from llama_stack.models.llama.datatypes import ( RawMediaItem, RawMessage, RawTextItem, - Role, StopReason, + ToolCall, + ToolDefinition, ToolPromptFormat, ) from llama_stack.models.llama.llama3.chat_format import ChatFormat -from llama_stack.models.llama.llama3.prompt_templates import ( - BuiltinToolGenerator, - FunctionTagCustomToolGenerator, - JsonCustomToolGenerator, - PythonListCustomToolGenerator, - SystemDefaultGenerator, -) from llama_stack.models.llama.llama3.tokenizer import Tokenizer -from llama_stack.models.llama.llama4.prompt_templates.system_prompts import ( - PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4, -) from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal -from llama_stack.providers.utils.inference import supported_inference_models log = get_logger(name=__name__, category="providers::utils") -class ChatCompletionRequestWithRawContent(ChatCompletionRequest): - messages: list[RawMessage] - - class CompletionRequestWithRawContent(CompletionRequest): content: RawContent @@ -103,28 +88,6 @@ def interleaved_content_as_str( return _process(content) -async def convert_request_to_raw( - request: ChatCompletionRequest | CompletionRequest, -) -> ChatCompletionRequestWithRawContent | CompletionRequestWithRawContent: - if isinstance(request, ChatCompletionRequest): - messages = [] - for m in request.messages: - content = await interleaved_content_convert_to_raw(m.content) - d = m.model_dump() - d["content"] = content - messages.append(RawMessage(**d)) - - d = request.model_dump() - d["messages"] = messages - request = ChatCompletionRequestWithRawContent(**d) - else: - d = request.model_dump() - d["content"] = await interleaved_content_convert_to_raw(request.content) - request = CompletionRequestWithRawContent(**d) - - return request - - async def interleaved_content_convert_to_raw( content: InterleavedContent, ) -> RawContent: @@ -171,6 +134,36 @@ async def interleaved_content_convert_to_raw( return await _localize_single(content) +async def convert_openai_message_to_raw_message(message: OpenAIMessageParam) -> RawMessage: + """Convert OpenAI message format to RawMessage format used by Llama formatters.""" + if isinstance(message, OpenAIUserMessageParam): + content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type] + return RawMessage(role="user", content=content) + elif isinstance(message, OpenAISystemMessageParam): + content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type] + return RawMessage(role="system", content=content) + elif isinstance(message, OpenAIAssistantMessageParam): + content = await interleaved_content_convert_to_raw(message.content 
or "") # type: ignore[arg-type] + tool_calls = [] + if message.tool_calls: + for tc in message.tool_calls: + if tc.function: + tool_calls.append( + ToolCall( + call_id=tc.id or "", + tool_name=tc.function.name or "", + arguments=tc.function.arguments or "{}", + ) + ) + return RawMessage(role="assistant", content=content, tool_calls=tool_calls) + elif isinstance(message, OpenAIToolMessageParam): + content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type] + return RawMessage(role="tool", content=content) + else: + # Handle OpenAIDeveloperMessageParam if needed + raise ValueError(f"Unsupported message type: {type(message)}") + + def content_has_media(content: InterleavedContent): def _has_media_content(c): return isinstance(c, ImageContentItem) @@ -181,17 +174,6 @@ def content_has_media(content: InterleavedContent): return _has_media_content(content) -def messages_have_media(messages: list[Message]): - return any(content_has_media(m.content) for m in messages) - - -def request_has_media(request: ChatCompletionRequest | CompletionRequest): - if isinstance(request, ChatCompletionRequest): - return messages_have_media(request.messages) - else: - return content_has_media(request.content) - - async def localize_image_content(uri: str) -> tuple[bytes, str] | None: if uri.startswith("http"): async with httpx.AsyncClient() as client: @@ -253,79 +235,6 @@ def augment_content_with_response_format_prompt(response_format, content): return content -async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str: - messages = chat_completion_request_to_messages(request, llama_model) - request.messages = messages - request = await convert_request_to_raw(request) - - formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) - model_input = formatter.encode_dialog_prompt( - request.messages, - tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model), - ) - return formatter.tokenizer.decode(model_input.tokens) - - -async def chat_completion_request_to_model_input_info( - request: ChatCompletionRequest, llama_model: str -) -> tuple[str, int]: - messages = chat_completion_request_to_messages(request, llama_model) - request.messages = messages - request = await convert_request_to_raw(request) - - formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) - model_input = formatter.encode_dialog_prompt( - request.messages, - tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model), - ) - return ( - formatter.tokenizer.decode(model_input.tokens), - len(model_input.tokens), - ) - - -def chat_completion_request_to_messages( - request: ChatCompletionRequest, - llama_model: str, -) -> list[Message]: - """Reads chat completion request and augments the messages to handle tools. - For eg. for llama_3_1, add system message with the appropriate tools or - add user messsage for custom tools, etc. - """ - assert llama_model is not None, "llama_model is required" - model = resolve_model(llama_model) - if model is None: - log.error(f"Could not resolve model {llama_model}") - return request.messages - - allowed_models = supported_inference_models() - descriptors = [m.descriptor() for m in allowed_models] - if model.descriptor() not in descriptors: - log.error(f"Unsupported inference model? 
{model.descriptor()}") - return request.messages - - if model.model_family == ModelFamily.llama3_1 or ( - model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id) - ): - # llama3.1 and llama3.2 multimodal models follow the same tool prompt format - messages = augment_messages_for_tools_llama_3_1(request) - elif model.model_family in ( - ModelFamily.llama3_2, - ModelFamily.llama3_3, - ): - # llama3.2, llama3.3 follow the same tool prompt format - messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator) - elif model.model_family == ModelFamily.llama4: - messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4) - else: - messages = request.messages - - if fmt_prompt := response_format_prompt(request.response_format): - messages.append(UserMessage(content=fmt_prompt)) - - return messages - - def response_format_prompt(fmt: ResponseFormat | None): if not fmt: return None @@ -338,128 +247,6 @@ def response_format_prompt(fmt: ResponseFormat | None): raise ValueError(f"Unknown response format {fmt.type}") -def augment_messages_for_tools_llama_3_1( - request: ChatCompletionRequest, -) -> list[Message]: - existing_messages = request.messages - existing_system_message = None - if existing_messages[0].role == Role.system.value: - existing_system_message = existing_messages.pop(0) - - assert existing_messages[0].role != Role.system.value, "Should only have 1 system message" - - messages = [] - - default_gen = SystemDefaultGenerator() - default_template = default_gen.gen() - - sys_content = "" - - tool_template = None - if request.tools: - tool_gen = BuiltinToolGenerator() - tool_template = tool_gen.gen(request.tools) - - sys_content += tool_template.render() - sys_content += "\n" - - sys_content += default_template.render() - - if existing_system_message: - # TODO: this fn is needed in many places - def _process(c): - if isinstance(c, str): - return c - else: - return "" - - sys_content += "\n" - - if isinstance(existing_system_message.content, str): - sys_content += _process(existing_system_message.content) - elif isinstance(existing_system_message.content, list): - sys_content += "\n".join([_process(c) for c in existing_system_message.content]) - - tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools) - if tool_choice_prompt: - sys_content += "\n" + tool_choice_prompt - - messages.append(SystemMessage(content=sys_content)) - - has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools) - if has_custom_tools: - fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json - if fmt == ToolPromptFormat.json: - tool_gen = JsonCustomToolGenerator() - elif fmt == ToolPromptFormat.function_tag: - tool_gen = FunctionTagCustomToolGenerator() - else: - raise ValueError(f"Non supported ToolPromptFormat {fmt}") - - custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)] - custom_template = tool_gen.gen(custom_tools) - messages.append(UserMessage(content=custom_template.render())) - - # Add back existing messages from the request - messages += existing_messages - - return messages - - -def augment_messages_for_tools_llama( - request: ChatCompletionRequest, - custom_tool_prompt_generator, -) -> list[Message]: - existing_messages = request.messages - existing_system_message = None - if existing_messages[0].role == Role.system.value: - existing_system_message = existing_messages.pop(0) - - assert 
existing_messages[0].role != Role.system.value, "Should only have 1 system message" - - sys_content = "" - custom_tools, builtin_tools = [], [] - for t in request.tools: - if isinstance(t.tool_name, str): - custom_tools.append(t) - else: - builtin_tools.append(t) - - if builtin_tools: - tool_gen = BuiltinToolGenerator() - tool_template = tool_gen.gen(builtin_tools) - - sys_content += tool_template.render() - sys_content += "\n" - - custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)] - if custom_tools: - fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list - if fmt != ToolPromptFormat.python_list: - raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}") - - system_prompt = None - if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace: - system_prompt = existing_system_message.content - - tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt) - - sys_content += tool_template.render() - sys_content += "\n" - - if existing_system_message and ( - request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools - ): - sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n") - - tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools) - if tool_choice_prompt: - sys_content += "\n" + tool_choice_prompt - - messages = [SystemMessage(content=sys_content.strip("\n")), *existing_messages] - return messages - - def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str: if tool_choice == ToolChoice.auto: return "" diff --git a/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index d047d9d12..86e6ea013 100644 --- a/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import ( VectorStoreContent, VectorStoreDeleteResponse, VectorStoreFileBatchObject, - VectorStoreFileContentsResponse, + VectorStoreFileContentResponse, VectorStoreFileCounts, VectorStoreFileDeleteResponse, VectorStoreFileLastError, @@ -921,22 +921,21 @@ class OpenAIVectorStoreMixin(ABC): self, vector_store_id: str, file_id: str, - ) -> VectorStoreFileContentsResponse: + ) -> VectorStoreFileContentResponse: """Retrieves the contents of a vector store file.""" if vector_store_id not in self.openai_vector_stores: raise VectorStoreNotFoundError(vector_store_id) - file_info = await self._load_openai_vector_store_file(vector_store_id, file_id) dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id) chunks = [Chunk.model_validate(c) for c in dict_chunks] content = [] for chunk in chunks: content.extend(self._chunk_to_vector_store_content(chunk)) - return VectorStoreFileContentsResponse( - file_id=file_id, - filename=file_info.get("filename", ""), - attributes=file_info.get("attributes", {}), - content=content, + return VectorStoreFileContentResponse( + object="vector_store.file_content.page", + data=content, + has_more=False, + next_page=None, ) async def openai_update_vector_store_file( diff --git a/tests/integration/agents/test_openai_responses.py b/tests/integration/agents/test_openai_responses.py index d413d5201..057cee774 100644 --- a/tests/integration/agents/test_openai_responses.py +++ 
b/tests/integration/agents/test_openai_responses.py
@@ -516,3 +516,169 @@ def test_response_with_instructions(openai_client, client_with_models, text_mode
     # Verify instructions from previous response was not carried over to the next response
     assert response_with_instructions2.instructions == instructions2
+
+
+@pytest.mark.skip(reason="Tool calling is not reliable.")
+def test_max_tool_calls_with_function_tools(openai_client, client_with_models, text_model_id):
+    """Test handling of max_tool_calls with function tools in responses."""
+    if isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI responses are not supported when testing with library client yet.")
+
+    client = openai_client
+    max_tool_calls = 1
+
+    tools = [
+        {
+            "type": "function",
+            "name": "get_weather",
+            "description": "Get weather information for a specified location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city name (e.g., 'New York', 'London')",
+                    },
+                },
+            },
+        },
+        {
+            "type": "function",
+            "name": "get_time",
+            "description": "Get current time for a specified location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city name (e.g., 'New York', 'London')",
+                    },
+                },
+            },
+        },
+    ]
+
+    # First create a response that triggers function tools
+    response = client.responses.create(
+        model=text_model_id,
+        input="Can you tell me the weather in Paris and the current time?",
+        tools=tools,
+        stream=False,
+        max_tool_calls=max_tool_calls,
+    )
+
+    # Verify we got two function calls and that max_tool_calls does not affect function tools
+    assert len(response.output) == 2
+    assert response.output[0].type == "function_call"
+    assert response.output[0].name == "get_weather"
+    assert response.output[0].status == "completed"
+    assert response.output[1].type == "function_call"
+    assert response.output[1].name == "get_time"
+    assert response.output[1].status == "completed"
+
+    # Verify we have a valid max_tool_calls field
+    assert response.max_tool_calls == max_tool_calls
+
+
+def test_max_tool_calls_invalid(openai_client, client_with_models, text_model_id):
+    """Test handling of invalid max_tool_calls in responses."""
+    if isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI responses are not supported when testing with library client yet.")
+
+    client = openai_client
+
+    input = "Search for today's top technology news."
+    invalid_max_tool_calls = 0
+    tools = [
+        {"type": "web_search"},
+    ]
+
+    # Create a response with an invalid max_tool_calls value i.e.
0 + # Handle ValueError from LLS and BadRequestError from OpenAI client + with pytest.raises((ValueError, BadRequestError)) as excinfo: + client.responses.create( + model=text_model_id, + input=input, + tools=tools, + stream=False, + max_tool_calls=invalid_max_tool_calls, + ) + + error_message = str(excinfo.value) + assert f"Invalid max_tool_calls={invalid_max_tool_calls}; should be >= 1" in error_message, ( + f"Expected error message about invalid max_tool_calls, got: {error_message}" + ) + + +def test_max_tool_calls_with_builtin_tools(openai_client, client_with_models, text_model_id): + """Test handling of max_tool_calls with built-in tools in responses.""" + if isinstance(client_with_models, LlamaStackAsLibraryClient): + pytest.skip("OpenAI responses are not supported when testing with library client yet.") + + client = openai_client + + input = "Search for today's top technology and a positive news story. You MUST make exactly two separate web search calls." + max_tool_calls = [1, 5] + tools = [ + {"type": "web_search"}, + ] + + # First create a response that triggers web_search tools without max_tool_calls + response = client.responses.create( + model=text_model_id, + input=input, + tools=tools, + stream=False, + ) + + # Verify we got two web search calls followed by a message + assert len(response.output) == 3 + assert response.output[0].type == "web_search_call" + assert response.output[0].status == "completed" + assert response.output[1].type == "web_search_call" + assert response.output[1].status == "completed" + assert response.output[2].type == "message" + assert response.output[2].status == "completed" + assert response.output[2].role == "assistant" + + # Next create a response that triggers web_search tools with max_tool_calls set to 1 + response_2 = client.responses.create( + model=text_model_id, + input=input, + tools=tools, + stream=False, + max_tool_calls=max_tool_calls[0], + ) + + # Verify we got one web search tool call followed by a message + assert len(response_2.output) == 2 + assert response_2.output[0].type == "web_search_call" + assert response_2.output[0].status == "completed" + assert response_2.output[1].type == "message" + assert response_2.output[1].status == "completed" + assert response_2.output[1].role == "assistant" + + # Verify we have a valid max_tool_calls field + assert response_2.max_tool_calls == max_tool_calls[0] + + # Finally create a response that triggers web_search tools with max_tool_calls set to 5 + response_3 = client.responses.create( + model=text_model_id, + input=input, + tools=tools, + stream=False, + max_tool_calls=max_tool_calls[1], + ) + + # Verify we got two web search calls followed by a message + assert len(response_3.output) == 3 + assert response_3.output[0].type == "web_search_call" + assert response_3.output[0].status == "completed" + assert response_3.output[1].type == "web_search_call" + assert response_3.output[1].status == "completed" + assert response_3.output[2].type == "message" + assert response_3.output[2].status == "completed" + assert response_3.output[2].role == "assistant" + + # Verify we have a valid max_tool_calls field + assert response_3.max_tool_calls == max_tool_calls[1] diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 1568ffbe2..4ce2850b4 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -54,6 +54,7 @@ def 
skip_if_model_doesnt_support_openai_completion(client_with_models, model_id) # {"error":{"message":"Unknown request URL: GET /openai/v1/completions. Please check the URL for typos, # or see the docs at https://console.groq.com/docs/","type":"invalid_request_error","code":"unknown_url"}} "remote::groq", + "remote::oci", "remote::gemini", # https://generativelanguage.googleapis.com/v1beta/openai/completions -> 404 "remote::anthropic", # at least claude-3-{5,7}-{haiku,sonnet}-* / claude-{sonnet,opus}-4-* are not supported "remote::azure", # {'error': {'code': 'OperationNotSupported', 'message': 'The completion operation diff --git a/tests/integration/inference/test_openai_embeddings.py b/tests/integration/inference/test_openai_embeddings.py index 704775716..fe8070162 100644 --- a/tests/integration/inference/test_openai_embeddings.py +++ b/tests/integration/inference/test_openai_embeddings.py @@ -138,6 +138,7 @@ def skip_if_model_doesnt_support_openai_embeddings(client, model_id): "remote::runpod", "remote::sambanova", "remote::tgi", + "remote::oci", ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.") diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 97ce4abe8..20f9d2978 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -907,16 +907,16 @@ def test_openai_vector_store_retrieve_file_contents( ) assert file_contents is not None - assert len(file_contents.content) == 1 - content = file_contents.content[0] + assert file_contents.object == "vector_store.file_content.page" + assert len(file_contents.data) == 1 + content = file_contents.data[0] # llama-stack-client returns a model, openai-python is a badboy and returns a dict if not isinstance(content, dict): content = content.model_dump() assert content["type"] == "text" assert content["text"] == test_content.decode("utf-8") - assert file_contents.filename == file_name - assert file_contents.attributes == attributes + assert file_contents.has_more is False @vector_provider_wrapper @@ -1483,14 +1483,12 @@ def test_openai_vector_store_file_batch_retrieve_contents( ) assert file_contents is not None - assert file_contents.filename == file_data[i][0] - assert len(file_contents.content) > 0 + assert file_contents.object == "vector_store.file_content.page" + assert len(file_contents.data) > 0 # Verify the content matches what we uploaded content_text = ( - file_contents.content[0].text - if hasattr(file_contents.content[0], "text") - else file_contents.content[0]["text"] + file_contents.data[0].text if hasattr(file_contents.data[0], "text") else file_contents.data[0]["text"] ) assert file_data[i][1].decode("utf-8") in content_text diff --git a/tests/unit/models/test_prompt_adapter.py b/tests/unit/models/test_prompt_adapter.py deleted file mode 100644 index d31426135..000000000 --- a/tests/unit/models/test_prompt_adapter.py +++ /dev/null @@ -1,303 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- - -from llama_stack.apis.inference import ( - ChatCompletionRequest, - CompletionMessage, - StopReason, - SystemMessage, - SystemMessageBehavior, - ToolCall, - ToolConfig, - UserMessage, -) -from llama_stack.models.llama.datatypes import ( - BuiltinTool, - ToolDefinition, - ToolPromptFormat, -) -from llama_stack.providers.utils.inference.prompt_adapter import ( - chat_completion_request_to_messages, - chat_completion_request_to_prompt, - interleaved_content_as_str, -) - -MODEL = "Llama3.1-8B-Instruct" -MODEL3_2 = "Llama3.2-3B-Instruct" - - -async def test_system_default(): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 2 - assert messages[-1].content == content - assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content) - - -async def test_system_builtin_only(): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition(tool_name=BuiltinTool.brave_search), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 2 - assert messages[-1].content == content - assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content) - assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content) - - -async def test_system_custom_only(): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ) - ], - tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json), - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 3 - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - - assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content) - assert messages[-1].content == content - - -async def test_system_custom_and_builtin(): - content = "Hello !" 
- request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition(tool_name=BuiltinTool.brave_search), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 3 - - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content) - - assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content) - assert messages[-1].content == content - - -async def test_completion_message_encoding(): - request = ChatCompletionRequest( - model=MODEL3_2, - messages=[ - UserMessage(content="hello"), - CompletionMessage( - content="", - stop_reason=StopReason.end_of_turn, - tool_calls=[ - ToolCall( - tool_name="custom1", - arguments='{"param1": "value1"}', # arguments must be a JSON string - call_id="123", - ) - ], - ), - ], - tools=[ - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ), - ], - tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list), - ) - prompt = await chat_completion_request_to_prompt(request, request.model) - assert '[custom1(param1="value1")]' in prompt - - request.model = MODEL - request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json) - prompt = await chat_completion_request_to_prompt(request, request.model) - assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt - - -async def test_user_provided_system_message(): - content = "Hello !" - system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 2 - assert interleaved_content_as_str(messages[0].content).endswith(system_prompt) - - assert messages[-1].content == content - - -async def test_replace_system_message_behavior_builtin_tools(): - content = "Hello !" - system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format=ToolPromptFormat.python_list, - system_message_behavior=SystemMessageBehavior.replace, - ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) - assert len(messages) == 2 - assert interleaved_content_as_str(messages[0].content).endswith(system_prompt) - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - assert messages[-1].content == content - - -async def test_replace_system_message_behavior_custom_tools(): - content = "Hello !" 
- system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format=ToolPromptFormat.python_list, - system_message_behavior=SystemMessageBehavior.replace, - ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) - - assert len(messages) == 2 - assert interleaved_content_as_str(messages[0].content).endswith(system_prompt) - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - assert messages[-1].content == content - - -async def test_replace_system_message_behavior_custom_tools_with_template(): - content = "Hello !" - system_prompt = "You are a pirate {{ function_description }}" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format=ToolPromptFormat.python_list, - system_message_behavior=SystemMessageBehavior.replace, - ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) - - assert len(messages) == 2 - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - assert "You are a pirate" in interleaved_content_as_str(messages[0].content) - # function description is present in the system prompt - assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content) - assert messages[-1].content == content diff --git a/tests/unit/providers/inline/inference/__init__.py b/tests/unit/providers/inline/inference/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/unit/providers/inline/inference/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/unit/providers/inline/inference/test_meta_reference.py b/tests/unit/providers/inline/inference/test_meta_reference.py new file mode 100644 index 000000000..381836397 --- /dev/null +++ b/tests/unit/providers/inline/inference/test_meta_reference.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from unittest.mock import Mock + +import pytest + +from llama_stack.providers.inline.inference.meta_reference.model_parallel import ( + ModelRunner, +) + + +class TestModelRunner: + """Test ModelRunner task dispatching for model-parallel inference.""" + + def test_chat_completion_task_dispatch(self): + """Verify ModelRunner correctly dispatches chat_completion tasks.""" + # Create a mock generator + mock_generator = Mock() + mock_generator.chat_completion = Mock(return_value=iter([])) + + runner = ModelRunner(mock_generator) + + # Create a chat_completion task + fake_params = {"model": "test"} + fake_messages = [{"role": "user", "content": "test"}] + task = ("chat_completion", [fake_params, fake_messages]) + + # Execute task + runner(task) + + # Verify chat_completion was called with correct arguments + mock_generator.chat_completion.assert_called_once_with(fake_params, fake_messages) + + def test_invalid_task_type_raises_error(self): + """Verify ModelRunner rejects invalid task types.""" + mock_generator = Mock() + runner = ModelRunner(mock_generator) + + with pytest.raises(ValueError, match="Unexpected task type"): + runner(("invalid_task", [])) diff --git a/tests/unit/providers/nvidia/test_safety.py b/tests/unit/providers/nvidia/test_safety.py index 922d7f61f..622302630 100644 --- a/tests/unit/providers/nvidia/test_safety.py +++ b/tests/unit/providers/nvidia/test_safety.py @@ -10,11 +10,13 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from llama_stack.apis.inference import CompletionMessage, UserMessage +from llama_stack.apis.inference import ( + OpenAIAssistantMessageParam, + OpenAIUserMessageParam, +) from llama_stack.apis.resource import ResourceType from llama_stack.apis.safety import RunShieldResponse, ViolationLevel from llama_stack.apis.shields import Shield -from llama_stack.models.llama.datatypes import StopReason from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter @@ -136,11 +138,9 @@ async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post): # Run the shield messages = [ - UserMessage(role="user", content="Hello, how are you?"), - CompletionMessage( - role="assistant", + OpenAIUserMessageParam(content="Hello, how are you?"), + OpenAIAssistantMessageParam( content="I'm doing well, thank you for asking!", - stop_reason=StopReason.end_of_message, tool_calls=[], ), ] @@ -191,13 +191,10 @@ async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post): # Mock Guardrails API response mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}} - # Run the shield messages = [ - UserMessage(role="user", content="Hello, how are you?"), - CompletionMessage( - role="assistant", + OpenAIUserMessageParam(content="Hello, how are you?"), + OpenAIAssistantMessageParam( content="I'm doing well, thank you for asking!", - stop_reason=StopReason.end_of_message, tool_calls=[], ), ] @@ -243,7 +240,7 @@ async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post): adapter.shield_store.get_shield.return_value = None messages = [ - UserMessage(role="user", content="Hello, how are you?"), + OpenAIUserMessageParam(content="Hello, how are you?"), ] with pytest.raises(ValueError): @@ -274,11 +271,9 @@ async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post): # Running the shield should raise an exception messages = [ - UserMessage(role="user", content="Hello, how are 
you?"), - CompletionMessage( - role="assistant", + OpenAIUserMessageParam(content="Hello, how are you?"), + OpenAIAssistantMessageParam( content="I'm doing well, thank you for asking!", - stop_reason=StopReason.end_of_message, tool_calls=[], ), ] diff --git a/tests/unit/providers/utils/inference/test_openai_compat.py b/tests/unit/providers/utils/inference/test_openai_compat.py deleted file mode 100644 index c200c4395..000000000 --- a/tests/unit/providers/utils/inference/test_openai_compat.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import pytest -from pydantic import ValidationError - -from llama_stack.apis.common.content_types import TextContentItem -from llama_stack.apis.inference import ( - CompletionMessage, - OpenAIAssistantMessageParam, - OpenAIChatCompletionContentPartImageParam, - OpenAIChatCompletionContentPartTextParam, - OpenAIDeveloperMessageParam, - OpenAIImageURL, - OpenAISystemMessageParam, - OpenAIToolMessageParam, - OpenAIUserMessageParam, - SystemMessage, - UserMessage, -) -from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall -from llama_stack.providers.utils.inference.openai_compat import ( - convert_message_to_openai_dict, - convert_message_to_openai_dict_new, - openai_messages_to_messages, -) - - -async def test_convert_message_to_openai_dict(): - message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user") - assert await convert_message_to_openai_dict(message) == { - "role": "user", - "content": [{"type": "text", "text": "Hello, world!"}], - } - - -# Test convert_message_to_openai_dict with a tool call -async def test_convert_message_to_openai_dict_with_tool_call(): - message = CompletionMessage( - content="", - tool_calls=[ToolCall(call_id="123", tool_name="test_tool", arguments='{"foo": "bar"}')], - stop_reason=StopReason.end_of_turn, - ) - - openai_dict = await convert_message_to_openai_dict(message) - - assert openai_dict == { - "role": "assistant", - "content": [{"type": "text", "text": ""}], - "tool_calls": [ - {"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}} - ], - } - - -async def test_convert_message_to_openai_dict_with_builtin_tool_call(): - message = CompletionMessage( - content="", - tool_calls=[ - ToolCall( - call_id="123", - tool_name=BuiltinTool.brave_search, - arguments='{"foo": "bar"}', - ) - ], - stop_reason=StopReason.end_of_turn, - ) - - openai_dict = await convert_message_to_openai_dict(message) - - assert openai_dict == { - "role": "assistant", - "content": [{"type": "text", "text": ""}], - "tool_calls": [ - {"id": "123", "type": "function", "function": {"name": "brave_search", "arguments": '{"foo": "bar"}'}} - ], - } - - -async def test_openai_messages_to_messages_with_content_str(): - openai_messages = [ - OpenAISystemMessageParam(content="system message"), - OpenAIUserMessageParam(content="user message"), - OpenAIAssistantMessageParam(content="assistant message"), - ] - - llama_messages = openai_messages_to_messages(openai_messages) - assert len(llama_messages) == 3 - assert isinstance(llama_messages[0], SystemMessage) - assert isinstance(llama_messages[1], UserMessage) - assert isinstance(llama_messages[2], CompletionMessage) - assert llama_messages[0].content == "system message" - assert llama_messages[1].content == "user message" - assert 
llama_messages[2].content == "assistant message" - - -async def test_openai_messages_to_messages_with_content_list(): - openai_messages = [ - OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]), - OpenAIUserMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="user message")]), - OpenAIAssistantMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="assistant message")]), - ] - - llama_messages = openai_messages_to_messages(openai_messages) - assert len(llama_messages) == 3 - assert isinstance(llama_messages[0], SystemMessage) - assert isinstance(llama_messages[1], UserMessage) - assert isinstance(llama_messages[2], CompletionMessage) - assert llama_messages[0].content[0].text == "system message" - assert llama_messages[1].content[0].text == "user message" - assert llama_messages[2].content[0].text == "assistant message" - - -@pytest.mark.parametrize( - "message_class,kwargs", - [ - (OpenAISystemMessageParam, {}), - (OpenAIAssistantMessageParam, {}), - (OpenAIDeveloperMessageParam, {}), - (OpenAIUserMessageParam, {}), - (OpenAIToolMessageParam, {"tool_call_id": "call_123"}), - ], -) -def test_message_accepts_text_string(message_class, kwargs): - """Test that messages accept string text content.""" - msg = message_class(content="Test message", **kwargs) - assert msg.content == "Test message" - - -@pytest.mark.parametrize( - "message_class,kwargs", - [ - (OpenAISystemMessageParam, {}), - (OpenAIAssistantMessageParam, {}), - (OpenAIDeveloperMessageParam, {}), - (OpenAIUserMessageParam, {}), - (OpenAIToolMessageParam, {"tool_call_id": "call_123"}), - ], -) -def test_message_accepts_text_list(message_class, kwargs): - """Test that messages accept list of text content parts.""" - content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")] - msg = message_class(content=content_list, **kwargs) - assert len(msg.content) == 1 - assert msg.content[0].text == "Test message" - - -@pytest.mark.parametrize( - "message_class,kwargs", - [ - (OpenAISystemMessageParam, {}), - (OpenAIAssistantMessageParam, {}), - (OpenAIDeveloperMessageParam, {}), - (OpenAIToolMessageParam, {"tool_call_id": "call_123"}), - ], -) -def test_message_rejects_images(message_class, kwargs): - """Test that system, assistant, developer, and tool messages reject image content.""" - with pytest.raises(ValidationError): - message_class( - content=[ - OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")) - ], - **kwargs, - ) - - -def test_user_message_accepts_images(): - """Test that user messages accept image content (unlike other message types).""" - # List with images should work - msg = OpenAIUserMessageParam( - content=[ - OpenAIChatCompletionContentPartTextParam(text="Describe this image:"), - OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")), - ] - ) - assert len(msg.content) == 2 - assert msg.content[0].text == "Describe this image:" - assert msg.content[1].image_url.url == "http://example.com/image.jpg" - - -async def test_convert_message_to_openai_dict_new_user_message(): - """Test convert_message_to_openai_dict_new with UserMessage.""" - message = UserMessage(content="Hello, world!", role="user") - result = await convert_message_to_openai_dict_new(message) - - assert result["role"] == "user" - assert result["content"] == "Hello, world!" 
- - -async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls(): - """Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls.""" - message = CompletionMessage( - content="I'll help you find the weather.", - tool_calls=[ - ToolCall( - call_id="call_123", - tool_name="get_weather", - arguments='{"city": "Sligo"}', - ) - ], - stop_reason=StopReason.end_of_turn, - ) - result = await convert_message_to_openai_dict_new(message) - - # This would have failed with "Cannot instantiate typing.Union" before the fix - assert result["role"] == "assistant" - assert result["content"] == "I'll help you find the weather." - assert "tool_calls" in result - assert result["tool_calls"] is not None - assert len(result["tool_calls"]) == 1 - - tool_call = result["tool_calls"][0] - assert tool_call.id == "call_123" - assert tool_call.type == "function" - assert tool_call.function.name == "get_weather" - assert tool_call.function.arguments == '{"city": "Sligo"}' diff --git a/tests/unit/providers/utils/inference/test_prompt_adapter.py b/tests/unit/providers/utils/inference/test_prompt_adapter.py new file mode 100644 index 000000000..62c8db74d --- /dev/null +++ b/tests/unit/providers/utils/inference/test_prompt_adapter.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.inference import ( + OpenAIAssistantMessageParam, + OpenAIUserMessageParam, +) +from llama_stack.models.llama.datatypes import RawTextItem +from llama_stack.providers.utils.inference.prompt_adapter import ( + convert_openai_message_to_raw_message, +) + + +class TestConvertOpenAIMessageToRawMessage: + """Test conversion of OpenAI message types to RawMessage format.""" + + async def test_user_message_conversion(self): + msg = OpenAIUserMessageParam(role="user", content="Hello world") + raw_msg = await convert_openai_message_to_raw_message(msg) + + assert raw_msg.role == "user" + assert isinstance(raw_msg.content, RawTextItem) + assert raw_msg.content.text == "Hello world" + + async def test_assistant_message_conversion(self): + msg = OpenAIAssistantMessageParam(role="assistant", content="Hi there!") + raw_msg = await convert_openai_message_to_raw_message(msg) + + assert raw_msg.role == "assistant" + assert isinstance(raw_msg.content, RawTextItem) + assert raw_msg.content.text == "Hi there!" + assert raw_msg.tool_calls == []
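
A minimal usage sketch (not part of the patch) for the new convert_openai_message_to_raw_message helper introduced in prompt_adapter.py above. It assumes an async entry point and uses only message types and imports that the diff itself introduces; the dialog contents are invented for illustration.

import asyncio

from llama_stack.apis.inference import (
    OpenAISystemMessageParam,
    OpenAIUserMessageParam,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_openai_message_to_raw_message,
)


async def main() -> None:
    # Hypothetical OpenAI-style dialog; any supported OpenAIMessageParam subtype works.
    messages = [
        OpenAISystemMessageParam(content="You are a helpful assistant."),
        OpenAIUserMessageParam(content="Hello!"),
    ]
    # Each message is converted to the RawMessage form used by the Llama formatters.
    raw_messages = [await convert_openai_message_to_raw_message(m) for m in messages]
    for raw in raw_messages:
        print(raw.role, raw.content)


asyncio.run(main())

Each returned RawMessage carries the role, converted content, and (for assistant messages) any tool calls, which is the shape the Llama3 ChatFormat encoder consumes in the meta-reference provider.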