forked from phoenix-oss/llama-stack-mirror
# What does this PR do? Addresses issue #679 - Adds support for the response_format field for chat completions and completions so users can get their outputs in JSON ## Test Plan <details> <summary>Integration tests</summary> `pytest llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output -k ollama -s -v` ```python llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_8b-ollama] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_3b-ollama] PASSED ================================== 2 passed, 18 deselected, 3 warnings in 41.41s ================================== ``` </details> <details> <summary>Manual Tests</summary> ``` export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct export OLLAMA_INFERENCE_MODEL=llama3.2:3b-instruct-fp16 export LLAMA_STACK_PORT=5000 ollama run $OLLAMA_INFERENCE_MODEL --keepalive 60m llama stack build --template ollama --image-type conda llama stack run ./run.yaml \ --port $LLAMA_STACK_PORT \ --env INFERENCE_MODEL=$INFERENCE_MODEL \ --env OLLAMA_URL=http://localhost:11434 ``` ```python client = LlamaStackClient(base_url=f"http://localhost:{os.environ['LLAMA_STACK_PORT']}") MODEL_ID=meta-llama/Llama-3.2-3B-Instruct prompt =f""" Create a step by step plan to complete the task of creating a codebase that is a web server that has an API endpoint that translates text from English to French. You have 3 different operations you can perform. You can create a file, update a file, or delete a file. Limit your step by step plan to only these operations per step. Don't create more than 10 steps. Please ensure there's a README.md file in the root of the codebase that describes the codebase and how to run it. Please ensure there's a requirements.txt file in the root of the codebase that describes the dependencies of the codebase. 
""" response = client.inference.chat_completion( model_id=MODEL_ID, messages=[ {"role": "user", "content": prompt}, ], sampling_params={ "max_tokens": 200000, }, response_format={ "type": "json_schema", "json_schema": { "$schema": "http://json-schema.org/draft-07/schema#", "title": "Plan", "description": f"A plan to complete the task of creating a codebase that is a web server that has an API endpoint that translates text from English to French.", "type": "object", "properties": { "steps": { "type": "array", "items": { "type": "string" } } }, "required": ["steps"], "additionalProperties": False, } }, stream=True, ) content = "" for chunk in response: if chunk.event.delta: print(chunk.event.delta, end="", flush=True) content += chunk.event.delta try: plan = json.loads(content) print(plan) except Exception as e: print(f"Error parsing plan into JSON: {e}") plan = {"steps": []} ``` Outputs: ```json { "steps": [ "Update the requirements.txt file to include the updated dependencies specified in the peer's feedback, including the Google Cloud Translation API key.", "Update the app.py file to address the code smells and incorporate the suggested improvements, such as handling errors and exceptions, initializing the Translator object correctly, adding input validation, using type hints and docstrings, and removing unnecessary logging statements.", "Create a README.md file that describes the codebase and how to run it.", "Ensure the README.md file is up-to-date and accurate.", "Update the requirements.txt file to reflect any additional dependencies specified by the peer's feedback.", "Add documentation for each function in the app.py file using docstrings.", "Implement logging statements throughout the app.py file to monitor application execution.", "Test the API endpoint to ensure it correctly translates text from English to French and handles errors properly.", "Refactor the code to follow PEP 8 style guidelines and ensure consistency in naming conventions, indentation, 
and spacing.", "Create a new folder for logs and add a logging configuration file (e.g., logconfig.json) that specifies the logging level and output destination.", "Deploy the web server on a production environment (e.g., AWS Elastic Beanstalk or Google Cloud Platform) to make it accessible to external users." ] } ``` </details> ## Sources - Ollama api docs: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion - Ollama structured output docs: https://github.com/ollama/ollama/blob/main/docs/api.md#request-structured-outputs ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [x] Wrote necessary unit or integration tests.
416 lines
14 KiB
Python
416 lines
14 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import logging
|
|
from typing import AsyncGenerator, List, Optional, Union
|
|
|
|
import httpx
|
|
from llama_models.datatypes import CoreModelId
|
|
|
|
from llama_models.llama3.api.chat_format import ChatFormat
|
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
|
from ollama import AsyncClient
|
|
|
|
from llama_stack.apis.common.content_types import (
|
|
ImageContentItem,
|
|
InterleavedContent,
|
|
TextContentItem,
|
|
)
|
|
from llama_stack.apis.inference import (
|
|
ChatCompletionRequest,
|
|
ChatCompletionResponse,
|
|
CompletionRequest,
|
|
EmbeddingsResponse,
|
|
Inference,
|
|
LogProbConfig,
|
|
Message,
|
|
ResponseFormat,
|
|
SamplingParams,
|
|
ToolChoice,
|
|
ToolDefinition,
|
|
ToolPromptFormat,
|
|
)
|
|
from llama_stack.apis.models import Model, ModelType
|
|
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
|
|
|
from llama_stack.providers.utils.inference.model_registry import (
|
|
build_model_alias,
|
|
build_model_alias_with_just_provider_model_id,
|
|
ModelRegistryHelper,
|
|
)
|
|
from llama_stack.providers.utils.inference.openai_compat import (
|
|
get_sampling_options,
|
|
OpenAICompatCompletionChoice,
|
|
OpenAICompatCompletionResponse,
|
|
process_chat_completion_response,
|
|
process_chat_completion_stream_response,
|
|
process_completion_response,
|
|
process_completion_stream_response,
|
|
)
|
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
|
chat_completion_request_to_prompt,
|
|
completion_request_to_prompt,
|
|
content_has_media,
|
|
convert_image_content_to_url,
|
|
interleaved_content_as_str,
|
|
request_has_media,
|
|
)
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
# Ollama model tags paired with the core Llama model ids they are registered under.
# Entries are grouped by model family; order within the list is unchanged.
model_aliases = [
    # --- Llama 3.1 ---
    build_model_alias(
        "llama3.1:8b-instruct-fp16", CoreModelId.llama3_1_8b_instruct.value
    ),
    build_model_alias_with_just_provider_model_id(
        "llama3.1:8b", CoreModelId.llama3_1_8b_instruct.value
    ),
    build_model_alias(
        "llama3.1:70b-instruct-fp16", CoreModelId.llama3_1_70b_instruct.value
    ),
    build_model_alias_with_just_provider_model_id(
        "llama3.1:70b", CoreModelId.llama3_1_70b_instruct.value
    ),
    build_model_alias(
        "llama3.1:405b-instruct-fp16", CoreModelId.llama3_1_405b_instruct.value
    ),
    build_model_alias_with_just_provider_model_id(
        "llama3.1:405b", CoreModelId.llama3_1_405b_instruct.value
    ),
    # --- Llama 3.2 (text) ---
    build_model_alias(
        "llama3.2:1b-instruct-fp16", CoreModelId.llama3_2_1b_instruct.value
    ),
    build_model_alias_with_just_provider_model_id(
        "llama3.2:1b", CoreModelId.llama3_2_1b_instruct.value
    ),
    build_model_alias(
        "llama3.2:3b-instruct-fp16", CoreModelId.llama3_2_3b_instruct.value
    ),
    build_model_alias_with_just_provider_model_id(
        "llama3.2:3b", CoreModelId.llama3_2_3b_instruct.value
    ),
    # --- Llama 3.2 (vision) ---
    build_model_alias(
        "llama3.2-vision:11b-instruct-fp16",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_model_alias_with_just_provider_model_id(
        "llama3.2-vision:latest", CoreModelId.llama3_2_11b_vision_instruct.value
    ),
    build_model_alias(
        "llama3.2-vision:90b-instruct-fp16",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_model_alias_with_just_provider_model_id(
        "llama3.2-vision:90b", CoreModelId.llama3_2_90b_vision_instruct.value
    ),
    # --- Llama 3.3 ---
    build_model_alias("llama3.3:70b", CoreModelId.llama3_3_70b_instruct.value),
    # --- Llama Guard ---
    # No full fp16 tags exist for the Llama Guard models, so their default
    # Ollama tags are aliased to the canonical SKUs instead.
    build_model_alias("llama-guard3:8b", CoreModelId.llama_guard_3_8b.value),
    build_model_alias("llama-guard3:1b", CoreModelId.llama_guard_3_1b.value),
]
|
|
|
|
|
|
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
    """Inference provider that forwards requests to an Ollama server.

    Completion, chat-completion and embedding calls are translated into
    ``ollama.AsyncClient`` calls (``generate``, ``chat``, ``embed``), and the
    Ollama results are wrapped in the OpenAI-compat structures expected by the
    shared ``process_*`` helpers.
    """

    def __init__(self, url: str) -> None:
        """Remember the server URL and build the registry helper and formatter.

        Args:
            url: Passed as ``host`` to every ``AsyncClient`` this adapter creates.
        """
        self.register_helper = ModelRegistryHelper(model_aliases)
        self.url = url
        self.formatter = ChatFormat(Tokenizer.get_instance())

    @property
    def client(self) -> AsyncClient:
        """Return an ``AsyncClient`` for ``self.url``.

        A new client object is constructed on every access.
        """
        return AsyncClient(host=self.url)

    async def initialize(self) -> None:
        """Check that the Ollama server is reachable.

        Raises:
            RuntimeError: If the ``ps`` call fails with ``httpx.ConnectError``.
        """
        log.info(f"checking connectivity to Ollama at `{self.url}`...")
        try:
            await self.client.ps()
        except httpx.ConnectError as e:
            raise RuntimeError(
                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
            ) from e

    async def shutdown(self) -> None:
        """No-op: the adapter holds nothing that needs releasing."""
        pass

    async def unregister_model(self, model_id: str) -> None:
        """No-op: nothing is done on the Ollama side when a model is unregistered."""
        pass

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Run a (non-chat) completion against Ollama.

        Args:
            model_id: Identifier looked up in ``self.model_store``; the model's
                ``provider_resource_id`` is what is sent to Ollama.
            content: Prompt content to complete.
            sampling_params: Sampling options, converted in ``_get_params``.
            response_format: Optional structured-output constraint; see
                ``_get_params`` for the supported types.
            stream: When truthy, return an async generator of chunks.
            logprobs: Stored on the request.

        Returns:
            The async generator from ``_stream_completion`` when ``stream`` is
            truthy, otherwise the awaited result of ``_nonstream_completion``.
        """
        model = await self.model_store.get_model(model_id)
        request = CompletionRequest(
            model=model.provider_resource_id,
            content=content,
            sampling_params=sampling_params,
            # Forward the caller's format; without this the argument was
            # silently dropped and _get_params never emitted "format".
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_completion(request)
        else:
            return await self._nonstream_completion(request)

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        """Yield processed completion chunks for a streaming request."""
        params = await self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            # Re-wrap each Ollama chunk in the OpenAI-compat shape consumed by
            # process_completion_stream_response.
            s = await self.client.generate(**params)
            async for chunk in s:
                choice = OpenAICompatCompletionChoice(
                    # done_reason is only read once Ollama marks the chunk done.
                    finish_reason=chunk["done_reason"] if chunk["done"] else None,
                    text=chunk["response"],
                )
                yield OpenAICompatCompletionResponse(
                    choices=[choice],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_completion_stream_response(stream, self.formatter):
            yield chunk

    async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        """Run one ``generate`` call and return the processed completion.

        NOTE(review): the return annotation says ``AsyncGenerator``, but this
        coroutine returns the value of ``process_completion_response`` directly.
        """
        params = await self._get_params(request)
        r = await self.client.generate(**params)

        choice = OpenAICompatCompletionChoice(
            finish_reason=r["done_reason"] if r["done"] else None,
            text=r["response"],
        )
        response = OpenAICompatCompletionResponse(
            choices=[choice],
        )

        return process_completion_response(response, self.formatter)

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Run a chat completion against Ollama.

        Args:
            model_id: Identifier looked up in ``self.model_store``; the model's
                ``provider_resource_id`` is what is sent to Ollama.
            messages: Conversation to complete.
            sampling_params: Sampling options, converted in ``_get_params``.
            response_format: Optional structured-output constraint; see
                ``_get_params`` for the supported types.
            tools: Tool definitions; ``None`` is stored as an empty list.
            tool_choice: Stored on the request.
            tool_prompt_format: Stored on the request.
            stream: When truthy, return an async generator of chunks.
            logprobs: Stored on the request.

        Returns:
            The async generator from ``_stream_chat_completion`` when ``stream``
            is truthy, otherwise the awaited result of
            ``_nonstream_chat_completion``.
        """
        model = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
            response_format=response_format,
        )
        if stream:
            return self._stream_chat_completion(request)
        else:
            return await self._nonstream_chat_completion(request)

    async def _get_params(
        self, request: Union[ChatCompletionRequest, CompletionRequest]
    ) -> dict:
        """Build the keyword arguments for ``client.chat`` / ``client.generate``.

        Chat requests that contain media produce a ``messages`` list; every
        other request is rendered to a single ``prompt`` sent with ``raw=True``.
        Callers pick ``chat`` vs ``generate`` by checking for ``"messages"``.

        Raises:
            AssertionError: A completion request contains media.
            NotImplementedError: ``response_format.type`` is ``"grammar"``.
            ValueError: ``response_format.type`` is any other unknown value.
        """
        sampling_options = get_sampling_options(request.sampling_params)
        # This is needed since the Ollama API expects num_predict to be set
        # for early truncation instead of max_tokens.
        if sampling_options.get("max_tokens") is not None:
            sampling_options["num_predict"] = sampling_options["max_tokens"]

        input_dict = {}
        media_present = request_has_media(request)
        if isinstance(request, ChatCompletionRequest):
            if media_present:
                contents = [
                    await convert_message_to_openai_dict_for_ollama(m)
                    for m in request.messages
                ]
                # flatten the list of lists
                input_dict["messages"] = [
                    item for sublist in contents for item in sublist
                ]
            else:
                # The prompt is already fully rendered by the formatter, so it
                # is sent raw (per the Ollama API, raw skips its own templating).
                input_dict["raw"] = True
                input_dict["prompt"] = await chat_completion_request_to_prompt(
                    request,
                    self.register_helper.get_llama_model(request.model),
                    self.formatter,
                )
        else:
            assert (
                not media_present
            ), "Ollama does not support media for Completion requests"
            input_dict["prompt"] = await completion_request_to_prompt(
                request, self.formatter
            )
            input_dict["raw"] = True

        if fmt := request.response_format:
            if fmt.type == "json_schema":
                # The JSON schema itself is passed as Ollama's "format" value.
                input_dict["format"] = fmt.json_schema
            elif fmt.type == "grammar":
                raise NotImplementedError("Grammar response format is not supported")
            else:
                raise ValueError(f"Unknown response format type: {fmt.type}")

        return {
            "model": request.model,
            **input_dict,
            "options": sampling_options,
            "stream": request.stream,
        }

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """Run one ``chat`` or ``generate`` call and return the processed result."""
        params = await self._get_params(request)
        # _get_params only emits "messages" for requests that contain media.
        if "messages" in params:
            r = await self.client.chat(**params)
        else:
            r = await self.client.generate(**params)

        # The text is read from "message" when present, else from "response".
        if "message" in r:
            choice = OpenAICompatCompletionChoice(
                finish_reason=r["done_reason"] if r["done"] else None,
                text=r["message"]["content"],
            )
        else:
            choice = OpenAICompatCompletionChoice(
                finish_reason=r["done_reason"] if r["done"] else None,
                text=r["response"],
            )
        response = OpenAICompatCompletionResponse(
            choices=[choice],
        )
        return process_chat_completion_response(response, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator:
        """Yield processed chat-completion chunks for a streaming request."""
        params = await self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            # Same chat/generate selection as the non-streaming path.
            if "messages" in params:
                s = await self.client.chat(**params)
            else:
                s = await self.client.generate(**params)
            async for chunk in s:
                if "message" in chunk:
                    choice = OpenAICompatCompletionChoice(
                        finish_reason=chunk["done_reason"] if chunk["done"] else None,
                        text=chunk["message"]["content"],
                    )
                else:
                    choice = OpenAICompatCompletionChoice(
                        finish_reason=chunk["done_reason"] if chunk["done"] else None,
                        text=chunk["response"],
                    )
                yield OpenAICompatCompletionResponse(
                    choices=[choice],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(
            stream, self.formatter
        ):
            yield chunk

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        """Embed each item of ``contents`` with the given model.

        Raises:
            AssertionError: Any item of ``contents`` contains media.
        """
        model = await self.model_store.get_model(model_id)

        assert all(
            not content_has_media(content) for content in contents
        ), "Ollama does not support media for embeddings"
        response = await self.client.embed(
            model=model.provider_resource_id,
            input=[interleaved_content_as_str(content) for content in contents],
        )
        embeddings = response["embeddings"]

        return EmbeddingsResponse(embeddings=embeddings)

    async def register_model(self, model: Model) -> Model:
        """Validate that ``model`` is known to the Ollama server and return it.

        Embedding models are checked against ``client.list()`` and returned
        unchanged; all other models first go through ``self.register_helper``
        and are then checked against ``client.ps()``.

        Raises:
            ValueError: ``model.provider_resource_id`` is not in the list
                reported by Ollama.
        """
        # ollama does not have embedding models running. Check if the model is in list of available models.
        if model.model_type == ModelType.embedding:
            response = await self.client.list()
            available_models = [m["model"] for m in response["models"]]
            if model.provider_resource_id not in available_models:
                raise ValueError(
                    f"Model '{model.provider_resource_id}' is not available in Ollama. "
                    f"Available models: {', '.join(available_models)}"
                )
            return model
        model = await self.register_helper.register_model(model)
        models = await self.client.ps()
        available_models = [m["model"] for m in models["models"]]
        if model.provider_resource_id not in available_models:
            raise ValueError(
                f"Model '{model.provider_resource_id}' is not available in Ollama. "
                f"Available models: {', '.join(available_models)}"
            )

        return model
|
|
|
|
|
|
async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
    """Turn one message into a list of Ollama chat message dicts.

    Each content item of the message becomes its own dict carrying the
    message's role: image items produce an ``"images"`` entry holding the
    converted image, and text items (or bare strings) produce a ``"content"``
    entry. Content that is not a list is treated as a single item.

    Raises:
        AssertionError: A non-image item does not resolve to a ``str``.
    """
    if isinstance(message.content, list):
        parts = message.content
    else:
        parts = [message.content]

    converted = []
    for part in parts:
        if isinstance(part, ImageContentItem):
            entry = {
                "role": message.role,
                "images": [
                    await convert_image_content_to_url(
                        part, download=True, include_format=False
                    )
                ],
            }
            converted.append(entry)
            continue

        # Text may arrive wrapped in a TextContentItem or as a plain string.
        text = part.text if isinstance(part, TextContentItem) else part
        assert isinstance(text, str)
        converted.append({"role": message.role, "content": text})

    return converted
|