Merge branch 'main' into evals_6

This commit is contained in:
Xi Yan 2024-10-24 17:29:22 -07:00
commit cdfd584a8f
14 changed files with 466 additions and 171 deletions

77
.github/ISSUE_TEMPLATE/bug.yml vendored Normal file
View file

@ -0,0 +1,77 @@
name: 🐛 Bug Report
description: Create a report to help us reproduce and fix the bug
body:
- type: markdown
attributes:
value: >
#### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the
existing and past issues](https://github.com/meta-llama/llama-stack/issues).
- type: textarea
id: system-info
attributes:
label: System Info
description: |
Please share your system info with us. You can use the following command to capture your environment information
python -m "torch.utils.collect_env"
placeholder: |
PyTorch version, CUDA version, GPU type, #num of GPUs...
validations:
required: true
- type: checkboxes
id: information-scripts-examples
attributes:
label: Information
description: 'The problem arises when using:'
options:
- label: "The official example scripts"
- label: "My own modified scripts"
- type: textarea
id: bug-description
attributes:
label: 🐛 Describe the bug
description: |
Please provide a clear and concise description of what the bug is.
Please also paste or describe the results you observe instead of the expected results.
placeholder: |
A clear and concise description of what the bug is.
```llama stack
# Command that you used for running the examples
```
Description of the results
validations:
required: true
- type: textarea
attributes:
label: Error logs
description: |
If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
placeholder: |
```
The error message you got, with the full traceback.
```
validations:
required: true
- type: textarea
id: expected-behavior
validations:
required: true
attributes:
label: Expected behavior
description: "A clear and concise description of what you would expect to happen."
- type: markdown
attributes:
value: >
Thanks for contributing 🎉!

View file

@ -0,0 +1,31 @@
name: 🚀 Feature request
description: Submit a proposal/request for a new llama-stack feature
body:
- type: textarea
id: feature-pitch
attributes:
label: 🚀 The feature, motivation and pitch
description: >
A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too.
validations:
required: true
- type: textarea
id: alternatives
attributes:
label: Alternatives
description: >
A description of any alternative solutions or features you've considered, if any.
- type: textarea
id: additional-context
attributes:
label: Additional context
description: >
Add any other context or screenshots about the feature request.
- type: markdown
attributes:
value: >
Thanks for contributing 🎉!

31
.github/PULL_REQUEST_TEMPLATE.md vendored Normal file
View file

@ -0,0 +1,31 @@
# What does this PR do?
Closes # (issue)
## Feature/Issue validation/testing/test plan
Please describe the tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced.
Please also list any relevant details for your test configuration or test plan.
- [ ] Test A
Logs for Test A
- [ ] Test B
Logs for Test B
## Sources
Please link relevant resources if necessary.
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?
Thanks for contributing 🎉!

View file

@ -65,23 +65,30 @@ A Distribution is where APIs and Providers are assembled together to provide a c
| Dell-TGI | [Local TGI + Chroma](https://hub.docker.com/repository/docker/llamastack/llamastack-local-tgi-chroma/general) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | Dell-TGI | [Local TGI + Chroma](https://hub.docker.com/repository/docker/llamastack/llamastack-local-tgi-chroma/general) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
## Installation ## Installation
You can install this repository as a [package](https://pypi.org/project/llama-stack/) with `pip install llama-stack` You have two ways to install this repository:
If you want to install from source: 1. **Install as a package**:
You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command:
```bash
pip install llama-stack
```
```bash 2. **Install from source**:
mkdir -p ~/local If you prefer to install from the source code, follow these steps:
cd ~/local ```bash
git clone git@github.com:meta-llama/llama-stack.git mkdir -p ~/local
cd ~/local
git clone git@github.com:meta-llama/llama-stack.git
conda create -n stack python=3.10 conda create -n stack python=3.10
conda activate stack conda activate stack
cd llama-stack cd llama-stack
$CONDA_PREFIX/bin/pip install -e . $CONDA_PREFIX/bin/pip install -e .
``` ```
## Documentations ## Documentations

View file

@ -5,163 +5,174 @@ This guide will walk you though the steps to get started on end-to-end flow for
## Installation ## Installation
The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package. The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package.
You can install this repository as a [package](https://pypi.org/project/llama-stack/) with `pip install llama-stack` You have two ways to install this repository:
If you want to install from source: 1. **Install as a package**:
You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command:
```bash
pip install llama-stack
```
```bash 2. **Install from source**:
mkdir -p ~/local If you prefer to install from the source code, follow these steps:
cd ~/local ```bash
git clone git@github.com:meta-llama/llama-stack.git mkdir -p ~/local
cd ~/local
git clone git@github.com:meta-llama/llama-stack.git
conda create -n stack python=3.10 conda create -n stack python=3.10
conda activate stack conda activate stack
cd llama-stack cd llama-stack
$CONDA_PREFIX/bin/pip install -e . $CONDA_PREFIX/bin/pip install -e .
``` ```
For what you can do with the Llama CLI, please refer to [CLI Reference](./cli_reference.md). For what you can do with the Llama CLI, please refer to [CLI Reference](./cli_reference.md).
## Starting Up Llama Stack Server ## Starting Up Llama Stack Server
#### Starting up server via docker
We provide 2 pre-built Docker image of Llama Stack distribution, which can be found in the following links. You have two ways to start up Llama stack server:
- [llamastack-local-gpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-gpu/general)
- This is a packaged version with our local meta-reference implementations, where you will be running inference locally with downloaded Llama model checkpoints.
- [llamastack-local-cpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-cpu/general)
- This is a lite version with remote inference where you can hook up to your favourite remote inference framework (e.g. ollama, fireworks, together, tgi) for running inference without GPU.
> [!NOTE] 1. **Starting up server via docker**:
> For GPU inference, you need to set these environment variables for specifying local directory containing your model checkpoints, and enable GPU inference to start running docker container.
```
export LLAMA_CHECKPOINT_DIR=~/.llama
```
> [!NOTE] We provide 2 pre-built Docker image of Llama Stack distribution, which can be found in the following links.
> `~/.llama` should be the path containing downloaded weights of Llama models. - [llamastack-local-gpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-gpu/general)
- This is a packaged version with our local meta-reference implementations, where you will be running inference locally with downloaded Llama model checkpoints.
- [llamastack-local-cpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-cpu/general)
- This is a lite version with remote inference where you can hook up to your favourite remote inference framework (e.g. ollama, fireworks, together, tgi) for running inference without GPU.
To download llama models, use > [!NOTE]
``` > For GPU inference, you need to set these environment variables for specifying local directory containing your model checkpoints, and enable GPU inference to start running docker container.
llama download --model-id Llama3.1-8B-Instruct ```
``` export LLAMA_CHECKPOINT_DIR=~/.llama
```
To download and start running a pre-built docker container, you may use the following commands: > [!NOTE]
> `~/.llama` should be the path containing downloaded weights of Llama models.
``` To download llama models, use
docker run -it -p 5000:5000 -v ~/.llama:/root/.llama --gpus=all llamastack/llamastack-local-gpu ```
``` llama download --model-id Llama3.1-8B-Instruct
```
> [!TIP] To download and start running a pre-built docker container, you may use the following commands:
> Pro Tip: We may use `docker compose up` for starting up a distribution with remote providers (e.g. TGI) using [llamastack-local-cpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-cpu/general). You can checkout [these scripts](../distributions/) to help you get started.
#### Build->Configure->Run Llama Stack server via conda ```
You may also build a LlamaStack distribution from scratch, configure it, and start running the distribution. This is useful for developing on LlamaStack. docker run -it -p 5000:5000 -v ~/.llama:/root/.llama --gpus=all llamastack/llamastack-local-gpu
```
**`llama stack build`** > [!TIP]
- You'll be prompted to enter build information interactively. > Pro Tip: We may use `docker compose up` for starting up a distribution with remote providers (e.g. TGI) using [llamastack-local-cpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-cpu/general). You can checkout [these scripts](../distributions/) to help you get started.
```
llama stack build
> Enter an unique name for identifying your Llama Stack build distribution (e.g. my-local-stack): my-local-stack
> Enter the image type you want your distribution to be built with (docker or conda): conda
Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs. 2. **Build->Configure->Run Llama Stack server via conda**:
> Enter the API provider for the inference API: (default=meta-reference): meta-reference
> Enter the API provider for the safety API: (default=meta-reference): meta-reference
> Enter the API provider for the agents API: (default=meta-reference): meta-reference
> Enter the API provider for the memory API: (default=meta-reference): meta-reference
> Enter the API provider for the telemetry API: (default=meta-reference): meta-reference
> (Optional) Enter a short description for your Llama Stack distribution: You may also build a LlamaStack distribution from scratch, configure it, and start running the distribution. This is useful for developing on LlamaStack.
Build spec configuration saved at ~/.conda/envs/llamastack-my-local-stack/my-local-stack-build.yaml **`llama stack build`**
You can now run `llama stack configure my-local-stack` - You'll be prompted to enter build information interactively.
``` ```
llama stack build
**`llama stack configure`** > Enter an unique name for identifying your Llama Stack build distribution (e.g. my-local-stack): my-local-stack
- Run `llama stack configure <name>` with the name you have previously defined in `build` step. > Enter the image type you want your distribution to be built with (docker or conda): conda
```
llama stack configure <name>
```
- You will be prompted to enter configurations for your Llama Stack
``` Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.
$ llama stack configure my-local-stack > Enter the API provider for the inference API: (default=meta-reference): meta-reference
> Enter the API provider for the safety API: (default=meta-reference): meta-reference
> Enter the API provider for the agents API: (default=meta-reference): meta-reference
> Enter the API provider for the memory API: (default=meta-reference): meta-reference
> Enter the API provider for the telemetry API: (default=meta-reference): meta-reference
Could not find my-local-stack. Trying conda build name instead... > (Optional) Enter a short description for your Llama Stack distribution:
Configuring API `inference`...
=== Configuring provider `meta-reference` for API inference...
Enter value for model (default: Llama3.1-8B-Instruct) (required):
Do you want to configure quantization? (y/n): n
Enter value for torch_seed (optional):
Enter value for max_seq_len (default: 4096) (required):
Enter value for max_batch_size (default: 1) (required):
Configuring API `safety`... Build spec configuration saved at ~/.conda/envs/llamastack-my-local-stack/my-local-stack-build.yaml
=== Configuring provider `meta-reference` for API safety... You can now run `llama stack configure my-local-stack`
Do you want to configure llama_guard_shield? (y/n): n ```
Do you want to configure prompt_guard_shield? (y/n): n
Configuring API `agents`... **`llama stack configure`**
=== Configuring provider `meta-reference` for API agents... - Run `llama stack configure <name>` with the name you have previously defined in `build` step.
Enter `type` for persistence_store (options: redis, sqlite, postgres) (default: sqlite): ```
llama stack configure <name>
```
- You will be prompted to enter configurations for your Llama Stack
Configuring SqliteKVStoreConfig: ```
Enter value for namespace (optional): $ llama stack configure my-local-stack
Enter value for db_path (default: /home/xiyan/.llama/runtime/kvstore.db) (required):
Configuring API `memory`... Could not find my-local-stack. Trying conda build name instead...
=== Configuring provider `meta-reference` for API memory... Configuring API `inference`...
> Please enter the supported memory bank type your provider has for memory: vector === Configuring provider `meta-reference` for API inference...
Enter value for model (default: Llama3.1-8B-Instruct) (required):
Do you want to configure quantization? (y/n): n
Enter value for torch_seed (optional):
Enter value for max_seq_len (default: 4096) (required):
Enter value for max_batch_size (default: 1) (required):
Configuring API `telemetry`... Configuring API `safety`...
=== Configuring provider `meta-reference` for API telemetry... === Configuring provider `meta-reference` for API safety...
Do you want to configure llama_guard_shield? (y/n): n
Do you want to configure prompt_guard_shield? (y/n): n
> YAML configuration has been written to ~/.llama/builds/conda/my-local-stack-run.yaml. Configuring API `agents`...
You can now run `llama stack run my-local-stack --port PORT` === Configuring provider `meta-reference` for API agents...
``` Enter `type` for persistence_store (options: redis, sqlite, postgres) (default: sqlite):
**`llama stack run`** Configuring SqliteKVStoreConfig:
- Run `llama stack run <name>` with the name you have previously defined. Enter value for namespace (optional):
``` Enter value for db_path (default: /home/xiyan/.llama/runtime/kvstore.db) (required):
llama stack run my-local-stack
... Configuring API `memory`...
> initializing model parallel with size 1 === Configuring provider `meta-reference` for API memory...
> initializing ddp with size 1 > Please enter the supported memory bank type your provider has for memory: vector
> initializing pipeline with size 1
... Configuring API `telemetry`...
Finished model load YES READY === Configuring provider `meta-reference` for API telemetry...
Serving POST /inference/chat_completion
Serving POST /inference/completion > YAML configuration has been written to ~/.llama/builds/conda/my-local-stack-run.yaml.
Serving POST /inference/embeddings You can now run `llama stack run my-local-stack --port PORT`
Serving POST /memory_banks/create ```
Serving DELETE /memory_bank/documents/delete
Serving DELETE /memory_banks/drop **`llama stack run`**
Serving GET /memory_bank/documents/get - Run `llama stack run <name>` with the name you have previously defined.
Serving GET /memory_banks/get ```
Serving POST /memory_bank/insert llama stack run my-local-stack
Serving GET /memory_banks/list
Serving POST /memory_bank/query ...
Serving POST /memory_bank/update > initializing model parallel with size 1
Serving POST /safety/run_shield > initializing ddp with size 1
Serving POST /agentic_system/create > initializing pipeline with size 1
Serving POST /agentic_system/session/create ...
Serving POST /agentic_system/turn/create Finished model load YES READY
Serving POST /agentic_system/delete Serving POST /inference/chat_completion
Serving POST /agentic_system/session/delete Serving POST /inference/completion
Serving POST /agentic_system/session/get Serving POST /inference/embeddings
Serving POST /agentic_system/step/get Serving POST /memory_banks/create
Serving POST /agentic_system/turn/get Serving DELETE /memory_bank/documents/delete
Serving GET /telemetry/get_trace Serving DELETE /memory_banks/drop
Serving POST /telemetry/log_event Serving GET /memory_bank/documents/get
Listening on :::5000 Serving GET /memory_banks/get
INFO: Started server process [587053] Serving POST /memory_bank/insert
INFO: Waiting for application startup. Serving GET /memory_banks/list
INFO: Application startup complete. Serving POST /memory_bank/query
INFO: Uvicorn running on http://[::]:5000 (Press CTRL+C to quit) Serving POST /memory_bank/update
``` Serving POST /safety/run_shield
Serving POST /agentic_system/create
Serving POST /agentic_system/session/create
Serving POST /agentic_system/turn/create
Serving POST /agentic_system/delete
Serving POST /agentic_system/session/delete
Serving POST /agentic_system/session/get
Serving POST /agentic_system/step/get
Serving POST /agentic_system/turn/get
Serving GET /telemetry/get_trace
Serving POST /telemetry/log_event
Listening on :::5000
INFO: Started server process [587053]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://[::]:5000 (Press CTRL+C to quit)
```
## Testing with client ## Testing with client

View file

@ -116,7 +116,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
"model": self.map_to_provider_model(request.model), "model": self.map_to_provider_model(request.model),
"prompt": chat_completion_request_to_prompt(request, self.formatter), "prompt": chat_completion_request_to_prompt(request, self.formatter),
"stream": request.stream, "stream": request.stream,
**get_sampling_options(request), **get_sampling_options(request.sampling_params),
} }
async def embeddings( async def embeddings(

View file

@ -116,7 +116,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
if prompt.startswith("<|begin_of_text|>"): if prompt.startswith("<|begin_of_text|>"):
prompt = prompt[len("<|begin_of_text|>") :] prompt = prompt[len("<|begin_of_text|>") :]
options = get_sampling_options(request) options = get_sampling_options(request.sampling_params)
options.setdefault("max_tokens", 512) options.setdefault("max_tokens", 512)
if fmt := request.response_format: if fmt := request.response_format:

View file

@ -110,7 +110,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return await self._nonstream_completion(request) return await self._nonstream_completion(request)
def _get_params_for_completion(self, request: CompletionRequest) -> dict: def _get_params_for_completion(self, request: CompletionRequest) -> dict:
sampling_options = get_sampling_options(request) sampling_options = get_sampling_options(request.sampling_params)
# This is needed since the Ollama API expects num_predict to be set # This is needed since the Ollama API expects num_predict to be set
# for early truncation instead of max_tokens. # for early truncation instead of max_tokens.
if sampling_options["max_tokens"] is not None: if sampling_options["max_tokens"] is not None:
@ -187,7 +187,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return { return {
"model": OLLAMA_SUPPORTED_MODELS[request.model], "model": OLLAMA_SUPPORTED_MODELS[request.model],
"prompt": chat_completion_request_to_prompt(request, self.formatter), "prompt": chat_completion_request_to_prompt(request, self.formatter),
"options": get_sampling_options(request), "options": get_sampling_options(request.sampling_params),
"raw": True, "raw": True,
"stream": request.stream, "stream": request.stream,
} }

View file

@ -24,9 +24,12 @@ from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionResponse, OpenAICompatCompletionResponse,
process_chat_completion_response, process_chat_completion_response,
process_chat_completion_stream_response, process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
) )
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_model_input_info, chat_completion_request_to_model_input_info,
completion_request_to_prompt_model_input_info,
) )
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
@ -75,7 +78,98 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
raise NotImplementedError() request = CompletionRequest(
model=model,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
def _get_max_new_tokens(self, sampling_params, input_tokens):
return min(
sampling_params.max_tokens or (self.max_tokens - input_tokens),
self.max_tokens - input_tokens - 1,
)
def _build_options(
self,
sampling_params: Optional[SamplingParams] = None,
fmt: ResponseFormat = None,
):
options = get_sampling_options(sampling_params)
# delete key "max_tokens" from options since its not supported by the API
options.pop("max_tokens", None)
if fmt:
if fmt.type == ResponseFormatType.json_schema.value:
options["grammar"] = {
"type": "json",
"value": fmt.schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
raise ValueError("Grammar response format not supported yet")
else:
raise ValueError(f"Unexpected response format: {fmt.type}")
return options
def _get_params_for_completion(self, request: CompletionRequest) -> dict:
prompt, input_tokens = completion_request_to_prompt_model_input_info(
request, self.formatter
)
return dict(
prompt=prompt,
stream=request.stream,
details=True,
max_new_tokens=self._get_max_new_tokens(
request.sampling_params, input_tokens
),
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**self._build_options(request.sampling_params, request.response_format),
)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = self._get_params_for_completion(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.text_generation(**params)
async for chunk in s:
token_result = chunk.token
finish_reason = None
if chunk.details:
finish_reason = chunk.details.finish_reason
choice = OpenAICompatCompletionChoice(
text=token_result.text, finish_reason=finish_reason
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_completion_stream_response(stream, self.formatter):
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = self._get_params_for_completion(request)
r = await self.client.text_generation(**params)
choice = OpenAICompatCompletionChoice(
finish_reason=r.details.finish_reason,
text="".join(t.text for t in r.details.tokens),
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_completion_response(response, self.formatter)
async def chat_completion( async def chat_completion(
self, self,
@ -146,29 +240,15 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
prompt, input_tokens = chat_completion_request_to_model_input_info( prompt, input_tokens = chat_completion_request_to_model_input_info(
request, self.formatter request, self.formatter
) )
max_new_tokens = min(
request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
self.max_tokens - input_tokens - 1,
)
options = get_sampling_options(request)
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
options["grammar"] = {
"type": "json",
"value": fmt.schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
raise ValueError("Grammar response format not supported yet")
else:
raise ValueError(f"Unexpected response format: {fmt.type}")
return dict( return dict(
prompt=prompt, prompt=prompt,
stream=request.stream, stream=request.stream,
details=True, details=True,
max_new_tokens=max_new_tokens, max_new_tokens=self._get_max_new_tokens(
request.sampling_params, input_tokens
),
stop_sequences=["<|eom_id|>", "<|eot_id|>"], stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**options, **self._build_options(request.sampling_params, request.response_format),
) )
async def embeddings( async def embeddings(

View file

@ -131,7 +131,7 @@ class TogetherInferenceAdapter(
yield chunk yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict: def _get_params(self, request: ChatCompletionRequest) -> dict:
options = get_sampling_options(request) options = get_sampling_options(request.sampling_params)
if fmt := request.response_format: if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value: if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = { options["response_format"] = {

View file

@ -143,7 +143,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
"model": VLLM_SUPPORTED_MODELS[request.model], "model": VLLM_SUPPORTED_MODELS[request.model],
"prompt": chat_completion_request_to_prompt(request, self.formatter), "prompt": chat_completion_request_to_prompt(request, self.formatter),
"stream": request.stream, "stream": request.stream,
**get_sampling_options(request), **get_sampling_options(request.sampling_params),
} }
async def embeddings( async def embeddings(

View file

@ -137,6 +137,7 @@ async def test_completion(inference_settings):
if provider.__provider_spec__.provider_type not in ( if provider.__provider_spec__.provider_type not in (
"meta-reference", "meta-reference",
"remote::ollama", "remote::ollama",
"remote::tgi",
): ):
pytest.skip("Other inference providers don't support completion() yet") pytest.skip("Other inference providers don't support completion() yet")
@ -170,6 +171,46 @@ async def test_completion(inference_settings):
assert last.stop_reason == StopReason.out_of_tokens assert last.stop_reason == StopReason.out_of_tokens
@pytest.mark.asyncio
async def test_completions_structured_output(inference_settings):
inference_impl = inference_settings["impl"]
params = inference_settings["common_params"]
provider = inference_impl.routing_table.get_provider_impl(params["model"])
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"remote::tgi",
):
pytest.skip(
"Other inference providers don't support structured output in completions yet"
)
class Output(BaseModel):
name: str
year_born: str
year_retired: str
user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
response = await inference_impl.completion(
content=f"input: '{user_input}'. the schema for json: {Output.schema()}, the json is: ",
stream=False,
model=params["model"],
sampling_params=SamplingParams(
max_tokens=50,
),
response_format=JsonResponseFormat(
schema=Output.model_json_schema(),
),
)
assert isinstance(response, CompletionResponse)
assert isinstance(response.content, str)
answer = Output.parse_raw(response.content)
assert answer.name == "Michael Jordan"
assert answer.year_born == "1963"
assert answer.year_retired == "2003"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages): async def test_chat_completion_non_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"] inference_impl = inference_settings["impl"]

View file

@ -29,9 +29,9 @@ class OpenAICompatCompletionResponse(BaseModel):
choices: List[OpenAICompatCompletionChoice] choices: List[OpenAICompatCompletionChoice]
def get_sampling_options(request: ChatCompletionRequest) -> dict: def get_sampling_options(params: SamplingParams) -> dict:
options = {} options = {}
if params := request.sampling_params: if params:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}: for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(params, attr): if getattr(params, attr):
options[attr] = getattr(params, attr) options[attr] = getattr(params, attr)
@ -64,7 +64,18 @@ def process_completion_response(
response: OpenAICompatCompletionResponse, formatter: ChatFormat response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> CompletionResponse: ) -> CompletionResponse:
choice = response.choices[0] choice = response.choices[0]
# drop suffix <eot_id> if present and return stop reason as end of turn
if choice.text.endswith("<|eot_id|>"):
return CompletionResponse(
stop_reason=StopReason.end_of_turn,
content=choice.text[: -len("<|eot_id|>")],
)
# drop suffix <eom_id> if present and return stop reason as end of message
if choice.text.endswith("<|eom_id|>"):
return CompletionResponse(
stop_reason=StopReason.end_of_message,
content=choice.text[: -len("<|eom_id|>")],
)
return CompletionResponse( return CompletionResponse(
stop_reason=get_stop_reason(choice.finish_reason), stop_reason=get_stop_reason(choice.finish_reason),
content=choice.text, content=choice.text,
@ -95,13 +106,6 @@ async def process_completion_stream_response(
choice = chunk.choices[0] choice = chunk.choices[0]
finish_reason = choice.finish_reason finish_reason = choice.finish_reason
if finish_reason:
if finish_reason in ["stop", "eos", "eos_token"]:
stop_reason = StopReason.end_of_turn
elif finish_reason == "length":
stop_reason = StopReason.out_of_tokens
break
text = text_from_choice(choice) text = text_from_choice(choice)
if text == "<|eot_id|>": if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn stop_reason = StopReason.end_of_turn
@ -115,6 +119,12 @@ async def process_completion_stream_response(
delta=text, delta=text,
stop_reason=stop_reason, stop_reason=stop_reason,
) )
if finish_reason:
if finish_reason in ["stop", "eos", "eos_token"]:
stop_reason = StopReason.end_of_turn
elif finish_reason == "length":
stop_reason = StopReason.out_of_tokens
break
yield CompletionResponseStreamChunk( yield CompletionResponseStreamChunk(
delta="", delta="",

View file

@ -31,6 +31,13 @@ def completion_request_to_prompt(
return formatter.tokenizer.decode(model_input.tokens) return formatter.tokenizer.decode(model_input.tokens)
def completion_request_to_prompt_model_input_info(
request: CompletionRequest, formatter: ChatFormat
) -> Tuple[str, int]:
model_input = formatter.encode_content(request.content)
return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))
def chat_completion_request_to_prompt( def chat_completion_request_to_prompt(
request: ChatCompletionRequest, formatter: ChatFormat request: ChatCompletionRequest, formatter: ChatFormat
) -> str: ) -> str: