llama-stack/llama_stack/providers/remote/inference/ollama/ollama.py
Kai Wu b6500974ec
removed assertion in ollama.py and fixed typo in the readme (#563)
# What does this PR do?
1. removed [incorrect
assertion](435f34b05e/llama_stack/providers/remote/inference/ollama/ollama.py (L183))
in ollama.py
2. fixed a typo in [this
line](435f34b05e/docs/source/distributions/importing_as_library.md (L24)),
as `model=` should be `model_id=`.

- [x] Addresses issue
([#issue562](https://github.com/meta-llama/llama-stack/issues/562))


## Test Plan

Tested with the following script:

```python
import asyncio
import os

# pip install aiosqlite ollama faiss
from llama_stack_client.lib.direct.direct import LlamaStackDirectClient
from llama_stack_client.types import SystemMessage, UserMessage


async def main():
    """Run one non-streaming chat completion through the "ollama" template and print it."""
    # Set before the client is built; presumably the template reads the model
    # to register from this variable (the config dump below lists this model).
    os.environ["INFERENCE_MODEL"] = "meta-llama/Llama-3.2-1B-Instruct"
    client = await LlamaStackDirectClient.from_template("ollama")
    await client.initialize()
    response = await client.models.list()
    print(response)
    # Use the first registered model; this raises IndexError if none is listed.
    model_name = response[0].identifier
    response = await client.inference.chat_completion(
        messages=[
            SystemMessage(content="You are a friendly assistant.", role="system"),
            UserMessage(
                content="hello world, write me a 2 sentence poem about the moon",
                role="user",
            ),
        ],
        # The keyword is `model_id`, not `model` (the typo this PR fixes in the docs).
        model_id=model_name,
        stream=False,
    )
    print("\nChat completion response:")
    print(response, type(response))


asyncio.run(main())

```
OUTPUT:
```
python test.py
Using template ollama with config:
apis:
- agents
- inference
- memory
- safety
- telemetry
conda_env: ollama
datasets: []
docker_image: null
eval_tasks: []
image_name: ollama
memory_banks: []
metadata_store:
  db_path: /Users/kaiwu/.llama/distributions/ollama/registry.db
  namespace: null
  type: sqlite
models:
- metadata: {}
  model_id: meta-llama/Llama-3.2-1B-Instruct
  provider_id: ollama
  provider_model_id: null
providers:
  agents:
  - config:
      persistence_store:
        db_path: /Users/kaiwu/.llama/distributions/ollama/agents_store.db
        namespace: null
        type: sqlite
    provider_id: meta-reference
    provider_type: inline::meta-reference
  inference:
  - config:
      url: http://localhost:11434
    provider_id: ollama
    provider_type: remote::ollama
  memory:
  - config:
      kvstore:
        db_path: /Users/kaiwu/.llama/distributions/ollama/faiss_store.db
        namespace: null
        type: sqlite
    provider_id: faiss
    provider_type: inline::faiss
  safety:
  - config: {}
    provider_id: llama-guard
    provider_type: inline::llama-guard
  telemetry:
  - config: {}
    provider_id: meta-reference
    provider_type: inline::meta-reference
scoring_fns: []
shields: []
version: '2'

[Model(identifier='meta-llama/Llama-3.2-1B-Instruct', provider_resource_id='llama3.2:1b-instruct-fp16', provider_id='ollama', type='model', metadata={})]

Chat completion response:
completion_message=CompletionMessage(role='assistant', content='Here is a short poem about the moon:\n\nThe moon glows bright in the midnight sky,\nA silver crescent shining, catching the eye.', stop_reason=<StopReason.end_of_turn: 'end_of_turn'>, tool_calls=[]) logprobs=None <class 'llama_stack.apis.inference.inference.ChatCompletionResponse'>
```

## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
2024-12-03 20:11:19 -08:00

360 lines
12 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import AsyncGenerator
import httpx
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from ollama import AsyncClient
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
build_model_alias_with_just_provider_model_id,
ModelRegistryHelper,
)
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
convert_image_media_to_url,
request_has_media,
)
log = logging.getLogger(__name__)
# Alias table handed to ModelRegistryHelper. Each row is
# (builder function, Ollama model name, CoreModelId member); the comprehension
# calls `builder(name, member.value)` row by row, so the resulting list keeps
# the order of the rows below.
model_aliases = [
    make_alias(ollama_name, core_model.value)
    for make_alias, ollama_name, core_model in (
        (
            build_model_alias,
            "llama3.1:8b-instruct-fp16",
            CoreModelId.llama3_1_8b_instruct,
        ),
        (
            build_model_alias_with_just_provider_model_id,
            "llama3.1:8b",
            CoreModelId.llama3_1_8b_instruct,
        ),
        (
            build_model_alias,
            "llama3.1:70b-instruct-fp16",
            CoreModelId.llama3_1_70b_instruct,
        ),
        (
            build_model_alias_with_just_provider_model_id,
            "llama3.1:70b",
            CoreModelId.llama3_1_70b_instruct,
        ),
        (
            build_model_alias,
            "llama3.1:405b-instruct-fp16",
            CoreModelId.llama3_1_405b_instruct,
        ),
        (
            build_model_alias_with_just_provider_model_id,
            "llama3.1:405b",
            CoreModelId.llama3_1_405b_instruct,
        ),
        (
            build_model_alias,
            "llama3.2:1b-instruct-fp16",
            CoreModelId.llama3_2_1b_instruct,
        ),
        (
            build_model_alias_with_just_provider_model_id,
            "llama3.2:1b",
            CoreModelId.llama3_2_1b_instruct,
        ),
        (
            build_model_alias,
            "llama3.2:3b-instruct-fp16",
            CoreModelId.llama3_2_3b_instruct,
        ),
        (
            build_model_alias_with_just_provider_model_id,
            "llama3.2:3b",
            CoreModelId.llama3_2_3b_instruct,
        ),
        (
            build_model_alias,
            "llama3.2-vision:11b-instruct-fp16",
            CoreModelId.llama3_2_11b_vision_instruct,
        ),
        (
            build_model_alias_with_just_provider_model_id,
            "llama3.2-vision",
            CoreModelId.llama3_2_11b_vision_instruct,
        ),
        (
            build_model_alias,
            "llama3.2-vision:90b-instruct-fp16",
            CoreModelId.llama3_2_90b_vision_instruct,
        ),
        (
            build_model_alias_with_just_provider_model_id,
            "llama3.2-vision:90b",
            CoreModelId.llama3_2_90b_vision_instruct,
        ),
        # The Llama Guard models don't have their full fp16 versions
        # so we are going to alias their default version to the canonical SKU
        (
            build_model_alias,
            "llama-guard3:8b",
            CoreModelId.llama_guard_3_8b,
        ),
        (
            build_model_alias,
            "llama-guard3:1b",
            CoreModelId.llama_guard_3_1b,
        ),
    )
]
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
    """Inference provider that forwards requests to an Ollama server.

    Completion and chat-completion requests are translated into Ollama
    ``generate``/``chat`` calls. Replies are wrapped in the OpenAI-compatible
    response types so the shared ``process_*`` helpers can convert them.

    NOTE(review): ``self.model_store`` is read by ``completion`` and
    ``chat_completion`` but never assigned in this class; presumably it is
    injected by the surrounding framework -- confirm.
    """

    def __init__(self, url: str) -> None:
        """Record the server URL and build the alias registry and formatter.

        Args:
            url: Passed as ``host`` to every ``AsyncClient`` this adapter creates.
        """
        self.register_helper = ModelRegistryHelper(model_aliases)
        self.url = url
        self.formatter = ChatFormat(Tokenizer.get_instance())

    @property
    def client(self) -> AsyncClient:
        """Return a new ``AsyncClient`` for ``self.url`` (one per access)."""
        return AsyncClient(host=self.url)

    async def initialize(self) -> None:
        """Check that the Ollama server is reachable.

        Raises:
            RuntimeError: If the connection attempt raises ``httpx.ConnectError``.
        """
        log.info(f"checking connectivity to Ollama at `{self.url}`...")
        try:
            await self.client.ps()
        except httpx.ConnectError as e:
            raise RuntimeError(
                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
            ) from e

    async def shutdown(self) -> None:
        """No-op; the adapter holds nothing that needs releasing."""
        pass

    async def unregister_model(self, model_id: str) -> None:
        """No-op; nothing is done on the Ollama side when a model is unregistered."""
        pass

    @staticmethod
    def _to_openai_compat_choice(payload) -> OpenAICompatCompletionChoice:
        """Convert one Ollama reply (or stream chunk) into an OpenAI-compat choice.

        Args:
            payload: A reply from ``client.chat`` or ``client.generate``. Only
                ``in`` and ``[...]`` access are used, so either a plain dict or
                a subscriptable response object works.

        Returns:
            A choice whose text comes from ``payload["message"]["content"]``
            when a ``"message"`` entry is present, otherwise from
            ``payload["response"]``. ``finish_reason`` is
            ``payload["done_reason"]`` once ``payload["done"]`` is truthy,
            and ``None`` before that.
        """
        finish_reason = payload["done_reason"] if payload["done"] else None
        if "message" in payload:
            text = payload["message"]["content"]
        else:
            text = payload["response"]
        return OpenAICompatCompletionChoice(
            finish_reason=finish_reason,
            text=text,
        )

    async def completion(
        self,
        model_id: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Run a completion for ``content`` on the given model.

        Args:
            model_id: Identifier looked up in ``self.model_store``; the model's
                ``provider_resource_id`` is what gets sent to Ollama.
            content: Prompt content.
            sampling_params: Sampling configuration forwarded with the request.
            response_format: Accepted but not forwarded by this adapter.
            stream: When truthy, an async generator of chunks is returned.
            logprobs: Forwarded on the request object.

        Returns:
            The async generator from ``_stream_completion`` when streaming,
            otherwise the awaited result of ``_nonstream_completion``.
        """
        model = await self.model_store.get_model(model_id)
        request = CompletionRequest(
            model=model.provider_resource_id,
            content=content,
            sampling_params=sampling_params,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_completion(request)
        else:
            return await self._nonstream_completion(request)

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        """Yield processed completion chunks as Ollama produces them."""
        params = await self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            s = await self.client.generate(**params)
            async for chunk in s:
                yield OpenAICompatCompletionResponse(
                    choices=[self._to_openai_compat_choice(chunk)],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_completion_stream_response(stream, self.formatter):
            yield chunk

    async def _nonstream_completion(
        self, request: CompletionRequest
    ) -> CompletionResponse:
        """Issue one ``generate`` call and return the processed completion."""
        params = await self._get_params(request)
        r = await self.client.generate(**params)
        response = OpenAICompatCompletionResponse(
            choices=[self._to_openai_compat_choice(r)],
        )
        return process_completion_response(response, self.formatter)

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Run a chat completion over ``messages`` on the given model.

        Args:
            model_id: Identifier looked up in ``self.model_store``; the model's
                ``provider_resource_id`` is what gets sent to Ollama.
            messages: Conversation so far.
            sampling_params: Sampling configuration forwarded with the request.
            response_format: Accepted but not forwarded by this adapter.
            tools: Tool definitions; ``None`` is normalised to an empty list.
            tool_choice: Forwarded on the request object.
            tool_prompt_format: Forwarded on the request object.
            stream: When truthy, an async generator of chunks is returned.
            logprobs: Forwarded on the request object.

        Returns:
            The async generator from ``_stream_chat_completion`` when
            streaming, otherwise the awaited result of
            ``_nonstream_chat_completion``.
        """
        model = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_chat_completion(request)
        else:
            return await self._nonstream_chat_completion(request)

    async def _get_params(
        self, request: Union[ChatCompletionRequest, CompletionRequest]
    ) -> dict:
        """Build the keyword arguments for ``client.chat`` / ``client.generate``.

        Chat requests that contain media produce a ``messages`` entry (the
        callers route those to ``client.chat``). Every other request is
        rendered to a single ``prompt`` with ``raw`` set to ``True``.

        Raises:
            AssertionError: If a plain completion request contains media.
        """
        sampling_options = get_sampling_options(request.sampling_params)
        # This is needed since the Ollama API expects num_predict to be set
        # for early truncation instead of max_tokens.
        if sampling_options.get("max_tokens") is not None:
            sampling_options["num_predict"] = sampling_options["max_tokens"]

        input_dict = {}
        media_present = request_has_media(request)
        if isinstance(request, ChatCompletionRequest):
            if media_present:
                contents = [
                    await convert_message_to_dict_for_ollama(m)
                    for m in request.messages
                ]
                # flatten the list of lists
                input_dict["messages"] = [
                    item for sublist in contents for item in sublist
                ]
            else:
                input_dict["raw"] = True
                input_dict["prompt"] = chat_completion_request_to_prompt(
                    request,
                    self.register_helper.get_llama_model(request.model),
                    self.formatter,
                )
        else:
            assert (
                not media_present
            ), "Ollama does not support media for Completion requests"
            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
            input_dict["raw"] = True

        return {
            "model": request.model,
            **input_dict,
            "options": sampling_options,
            "stream": request.stream,
        }

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """Issue one ``chat`` or ``generate`` call and return the processed reply."""
        params = await self._get_params(request)
        if "messages" in params:
            r = await self.client.chat(**params)
        else:
            r = await self.client.generate(**params)

        # Deliberately no `assert isinstance(r, dict)` here: newer ollama
        # clients return typed response objects rather than plain dicts, and
        # those still support the `in` / `[...]` access used by the helper,
        # so the assertion only rejected valid replies.
        response = OpenAICompatCompletionResponse(
            choices=[self._to_openai_compat_choice(r)],
        )
        return process_chat_completion_response(response, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator:
        """Yield processed chat-completion chunks as Ollama produces them."""
        params = await self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            if "messages" in params:
                s = await self.client.chat(**params)
            else:
                s = await self.client.generate(**params)
            async for chunk in s:
                yield OpenAICompatCompletionResponse(
                    choices=[self._to_openai_compat_choice(chunk)],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(
            stream, self.formatter
        ):
            yield chunk

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        """Not supported by this adapter; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    async def register_model(self, model: Model) -> Model:
        """Register ``model`` and verify that Ollama reports it.

        The model is first passed through ``self.register_helper``; its
        ``provider_resource_id`` must then appear among the model names
        returned by ``client.ps()``.

        Raises:
            ValueError: If the resolved id is not in the ``ps()`` listing.
        """
        model = await self.register_helper.register_model(model)
        models = await self.client.ps()
        available_models = [m["model"] for m in models["models"]]
        if model.provider_resource_id not in available_models:
            raise ValueError(
                f"Model '{model.provider_resource_id}' is not available in Ollama. "
                f"Available models: {', '.join(available_models)}"
            )

        return model
async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
    """Expand one message into a list of Ollama chat-message dicts.

    A message whose content is a list yields one dict per element; any other
    content yields a single dict. Every dict carries the message's ``role``.
    ``ImageMedia`` elements are placed under ``"images"`` after conversion
    with ``convert_image_media_to_url(..., download=True,
    include_format=False)``; all other elements are placed under
    ``"content"`` unchanged.
    """
    if isinstance(message.content, list):
        parts = message.content
    else:
        parts = [message.content]

    converted = []
    for part in parts:
        # The role is read before any await, as a dict literal would do.
        entry = {"role": message.role}
        if isinstance(part, ImageMedia):
            entry["images"] = [
                await convert_image_media_to_url(
                    part, download=True, include_format=False
                )
            ]
        else:
            entry["content"] = part
        converted.append(entry)
    return converted