forked from phoenix-oss/llama-stack-mirror
		
	# What does this PR do? 1. removed [incorrect assertion](435f34b05e/llama_stack/providers/remote/inference/ollama/ollama.py (L183)) in ollama.py 2. fixed a typo in [this line](435f34b05e/docs/source/distributions/importing_as_library.md (L24)), as `model=` should be `model_id=` . - [x] Addresses issue ([#issue562](https://github.com/meta-llama/llama-stack/issues/562)) ## Test Plan tested with code: ```python import asyncio import os # pip install aiosqlite ollama faiss from llama_stack_client.lib.direct.direct import LlamaStackDirectClient from llama_stack_client.types import SystemMessage, UserMessage async def main(): os.environ["INFERENCE_MODEL"] = "meta-llama/Llama-3.2-1B-Instruct" client = await LlamaStackDirectClient.from_template("ollama") await client.initialize() response = await client.models.list() print(response) model_name = response[0].identifier response = await client.inference.chat_completion( messages=[ SystemMessage(content="You are a friendly assistant.", role="system"), UserMessage( content="hello world, write me a 2 sentence poem about the moon", role="user", ), ], model_id=model_name, stream=False, ) print("\nChat completion response:") print(response, type(response)) asyncio.run(main()) ``` OUTPUT: ``` python test.py Using template ollama with config: apis: - agents - inference - memory - safety - telemetry conda_env: ollama datasets: [] docker_image: null eval_tasks: [] image_name: ollama memory_banks: [] metadata_store: db_path: /Users/kaiwu/.llama/distributions/ollama/registry.db namespace: null type: sqlite models: - metadata: {} model_id: meta-llama/Llama-3.2-1B-Instruct provider_id: ollama provider_model_id: null providers: agents: - config: persistence_store: db_path: /Users/kaiwu/.llama/distributions/ollama/agents_store.db namespace: null type: sqlite provider_id: meta-reference provider_type: inline::meta-reference inference: - config: url: http://localhost:11434 provider_id: ollama provider_type: remote::ollama memory: - config: 
kvstore: db_path: /Users/kaiwu/.llama/distributions/ollama/faiss_store.db namespace: null type: sqlite provider_id: faiss provider_type: inline::faiss safety: - config: {} provider_id: llama-guard provider_type: inline::llama-guard telemetry: - config: {} provider_id: meta-reference provider_type: inline::meta-reference scoring_fns: [] shields: [] version: '2' [Model(identifier='meta-llama/Llama-3.2-1B-Instruct', provider_resource_id='llama3.2:1b-instruct-fp16', provider_id='ollama', type='model', metadata={})] Chat completion response: completion_message=CompletionMessage(role='assistant', content='Here is a short poem about the moon:\n\nThe moon glows bright in the midnight sky,\nA silver crescent shining, catching the eye.', stop_reason=<StopReason.end_of_turn: 'end_of_turn'>, tool_calls=[]) logprobs=None <class 'llama_stack.apis.inference.inference.ChatCompletionResponse'> ``` ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
		
			
				
	
	
		
			360 lines
		
	
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			360 lines
		
	
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import logging
 | |
| from typing import AsyncGenerator
 | |
| 
 | |
| import httpx
 | |
| from llama_models.datatypes import CoreModelId
 | |
| 
 | |
| from llama_models.llama3.api.chat_format import ChatFormat
 | |
| from llama_models.llama3.api.datatypes import Message
 | |
| from llama_models.llama3.api.tokenizer import Tokenizer
 | |
| from ollama import AsyncClient
 | |
| 
 | |
| from llama_stack.providers.utils.inference.model_registry import (
 | |
|     build_model_alias,
 | |
|     build_model_alias_with_just_provider_model_id,
 | |
|     ModelRegistryHelper,
 | |
| )
 | |
| 
 | |
| from llama_stack.apis.inference import *  # noqa: F403
 | |
| from llama_stack.providers.datatypes import ModelsProtocolPrivate
 | |
| 
 | |
| from llama_stack.providers.utils.inference.openai_compat import (
 | |
|     get_sampling_options,
 | |
|     OpenAICompatCompletionChoice,
 | |
|     OpenAICompatCompletionResponse,
 | |
|     process_chat_completion_response,
 | |
|     process_chat_completion_stream_response,
 | |
|     process_completion_response,
 | |
|     process_completion_stream_response,
 | |
| )
 | |
| from llama_stack.providers.utils.inference.prompt_adapter import (
 | |
|     chat_completion_request_to_prompt,
 | |
|     completion_request_to_prompt,
 | |
|     convert_image_media_to_url,
 | |
|     request_has_media,
 | |
| )
 | |
| 
 | |
log = logging.getLogger(__name__)

# Every Llama model in the table below is exposed under two Ollama tags: the
# explicit fp16 instruct tag (registered with ``build_model_alias``) followed
# by Ollama's shorter tag (registered with
# ``build_model_alias_with_just_provider_model_id``). The comprehension expands
# each (fp16 tag, short tag, core model id) row into that pair, in row order.
model_aliases = [
    alias
    for fp16_tag, short_tag, core_id in (
        ("llama3.1:8b-instruct-fp16", "llama3.1:8b", CoreModelId.llama3_1_8b_instruct),
        ("llama3.1:70b-instruct-fp16", "llama3.1:70b", CoreModelId.llama3_1_70b_instruct),
        (
            "llama3.1:405b-instruct-fp16",
            "llama3.1:405b",
            CoreModelId.llama3_1_405b_instruct,
        ),
        ("llama3.2:1b-instruct-fp16", "llama3.2:1b", CoreModelId.llama3_2_1b_instruct),
        ("llama3.2:3b-instruct-fp16", "llama3.2:3b", CoreModelId.llama3_2_3b_instruct),
        (
            "llama3.2-vision:11b-instruct-fp16",
            "llama3.2-vision",
            CoreModelId.llama3_2_11b_vision_instruct,
        ),
        (
            "llama3.2-vision:90b-instruct-fp16",
            "llama3.2-vision:90b",
            CoreModelId.llama3_2_90b_vision_instruct,
        ),
    )
    for alias in (
        build_model_alias(fp16_tag, core_id.value),
        build_model_alias_with_just_provider_model_id(short_tag, core_id.value),
    )
] + [
    # The Llama Guard models don't have their full fp16 versions
    # so we are going to alias their default version to the canonical SKU
    build_model_alias("llama-guard3:8b", CoreModelId.llama_guard_3_8b.value),
    build_model_alias("llama-guard3:1b", CoreModelId.llama_guard_3_1b.value),
]
 | |
| 
 | |
| 
 | |
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
    """Llama Stack inference provider backed by an Ollama server.

    Chat requests that contain media are sent to Ollama's ``chat`` endpoint as
    structured messages. Every other request is rendered into a raw Llama
    prompt locally (with ``self.formatter``) and sent to ``generate`` with
    ``raw=True``. Ollama payloads are then normalized through the shared
    OpenAI-compat processing helpers.
    """

    def __init__(self, url: str) -> None:
        """Create the adapter.

        Args:
            url: Address of the Ollama server, passed as ``host`` to
                ``ollama.AsyncClient``.
        """
        # Resolves Llama Stack model ids to Ollama tags using ``model_aliases``.
        self.register_helper = ModelRegistryHelper(model_aliases)
        self.url = url
        # Renders prompts locally and decodes raw completions.
        self.formatter = ChatFormat(Tokenizer.get_instance())

    @property
    def client(self) -> AsyncClient:
        """Return an Ollama client pointed at ``self.url``.

        NOTE(review): a new ``AsyncClient`` is constructed on every access and
        is never explicitly closed in this class; consider caching a client
        (e.g. per event loop) if connection churn becomes a problem.
        """
        return AsyncClient(host=self.url)

    async def initialize(self) -> None:
        """Check that the Ollama server is reachable.

        Raises:
            RuntimeError: if the connectivity probe fails with a connect error.
        """
        log.info(f"checking connectivity to Ollama at `{self.url}`...")
        try:
            await self.client.ps()
        except httpx.ConnectError as e:
            raise RuntimeError(
                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
            ) from e

    async def shutdown(self) -> None:
        """No-op: the adapter holds no long-lived client (see ``client``)."""
        pass

    async def unregister_model(self, model_id: str) -> None:
        """Currently a no-op."""
        pass

    async def completion(
        self,
        model_id: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Run a raw-prompt text completion on a registered model.

        Args:
            model_id: Llama Stack model id, resolved via ``self.model_store``.
            content: Prompt content; media is rejected for completions.
            sampling_params: Sampling configuration; ``None`` means defaults.
            response_format: Accepted for interface compatibility; this adapter
                does not forward it to Ollama.
            stream: When true, return an async generator of stream chunks.
            logprobs: Stored on the request object.

        Returns:
            An async generator of chunks when ``stream`` is true, otherwise the
            processed (non-streaming) completion response.
        """
        # An explicit None would otherwise reach get_sampling_options().
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self.model_store.get_model(model_id)
        request = CompletionRequest(
            model=model.provider_resource_id,
            content=content,
            sampling_params=sampling_params,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_completion(request)
        return await self._nonstream_completion(request)

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        """Stream a completion from Ollama ``generate`` as Llama Stack chunks."""
        params = await self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            s = await self.client.generate(**params)
            async for chunk in s:
                yield OpenAICompatCompletionResponse(
                    choices=[self._to_openai_compat_choice(chunk)],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_completion_stream_response(stream, self.formatter):
            yield chunk

    async def _nonstream_completion(
        self, request: CompletionRequest
    ) -> CompletionResponse:
        """Run a single Ollama ``generate`` call and convert its reply."""
        params = await self._get_params(request)
        r = await self.client.generate(**params)
        response = OpenAICompatCompletionResponse(
            choices=[self._to_openai_compat_choice(r)],
        )
        return process_completion_response(response, self.formatter)

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Run a chat completion on a registered model.

        Args:
            model_id: Llama Stack model id, resolved via ``self.model_store``.
            messages: Conversation history.
            sampling_params: Sampling configuration; ``None`` means defaults.
            response_format: Accepted for interface compatibility; this adapter
                does not forward it to Ollama.
            tools: Tool definitions; ``None`` is treated as an empty list.
            tool_choice: Tool selection policy stored on the request.
            tool_prompt_format: Format used when rendering tools into prompts.
            stream: When true, return an async generator of stream chunks.
            logprobs: Stored on the request object.

        Returns:
            An async generator of chunks when ``stream`` is true, otherwise the
            processed chat completion response.
        """
        # An explicit None would otherwise reach get_sampling_options().
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_chat_completion(request)
        return await self._nonstream_chat_completion(request)

    @staticmethod
    def _to_openai_compat_choice(r) -> OpenAICompatCompletionChoice:
        """Convert one Ollama ``chat``/``generate`` payload into a compat choice.

        ``chat`` payloads carry their text under ``message.content``, while
        ``generate`` payloads carry it under ``response``. ``done_reason`` is
        only read once ``done`` is set.
        """
        finish_reason = r["done_reason"] if r["done"] else None
        text = r["message"]["content"] if "message" in r else r["response"]
        return OpenAICompatCompletionChoice(finish_reason=finish_reason, text=text)

    async def _get_params(
        self, request: Union[ChatCompletionRequest, CompletionRequest]
    ) -> dict:
        """Build keyword arguments for ``AsyncClient.chat`` or ``generate``.

        The result holds ``model``, ``options`` and ``stream``, plus either
        ``messages`` (chat requests with media, sent to ``chat``) or
        ``prompt`` with ``raw=True`` (everything else, sent to ``generate``).

        Raises:
            ValueError: if a (non-chat) completion request contains media.
        """
        sampling_options = get_sampling_options(request.sampling_params)
        # This is needed since the Ollama API expects num_predict to be set
        # for early truncation instead of max_tokens.
        if sampling_options.get("max_tokens") is not None:
            sampling_options["num_predict"] = sampling_options["max_tokens"]

        input_dict = {}
        media_present = request_has_media(request)
        if isinstance(request, ChatCompletionRequest):
            if media_present:
                contents = [
                    await convert_message_to_dict_for_ollama(m)
                    for m in request.messages
                ]
                # Each message may expand into several Ollama messages (one
                # per content item), so flatten the list of lists.
                input_dict["messages"] = [
                    item for sublist in contents for item in sublist
                ]
            else:
                input_dict["raw"] = True
                input_dict["prompt"] = chat_completion_request_to_prompt(
                    request,
                    self.register_helper.get_llama_model(request.model),
                    self.formatter,
                )
        else:
            # Raise instead of assert so the check survives `python -O`.
            if media_present:
                raise ValueError(
                    "Ollama does not support media for Completion requests"
                )
            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
            input_dict["raw"] = True

        return {
            "model": request.model,
            **input_dict,
            "options": sampling_options,
            "stream": request.stream,
        }

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """Run one Ollama ``chat``/``generate`` call and convert its reply."""
        params = await self._get_params(request)
        if "messages" in params:
            r = await self.client.chat(**params)
        else:
            r = await self.client.generate(**params)
        # No `isinstance(r, dict)` check here: ollama-python 0.4+ returns
        # subscriptable response models rather than dict instances, and only
        # key access is needed below.
        response = OpenAICompatCompletionResponse(
            choices=[self._to_openai_compat_choice(r)],
        )
        return process_chat_completion_response(response, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator:
        """Stream an Ollama ``chat``/``generate`` reply as Llama Stack chunks."""
        params = await self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            if "messages" in params:
                s = await self.client.chat(**params)
            else:
                s = await self.client.generate(**params)
            async for chunk in s:
                yield OpenAICompatCompletionResponse(
                    choices=[self._to_openai_compat_choice(chunk)],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(
            stream, self.formatter
        ):
            yield chunk

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        """Embeddings are not implemented by this adapter.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    async def register_model(self, model: Model) -> Model:
        """Register ``model`` and ensure Ollama reports its provider tag.

        Raises:
            ValueError: if the resolved ``provider_resource_id`` is not among
                the models returned by ``client.ps()``.
        """
        model = await self.register_helper.register_model(model)
        # NOTE(review): per the Ollama API docs, `ps` lists models currently
        # loaded in memory (not every pulled model), so an idle-but-pulled
        # model is rejected here; confirm this is intended.
        models = await self.client.ps()
        available_models = [m["model"] for m in models["models"]]
        if model.provider_resource_id not in available_models:
            raise ValueError(
                f"Model '{model.provider_resource_id}' is not available in Ollama. "
                f"Available models: {', '.join(available_models)}"
            )

        return model
 | |
| 
 | |
| 
 | |
async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
    """Split a Llama Stack message into Ollama chat message dicts.

    Each content item becomes its own dict carrying the message's role.
    ``ImageMedia`` items become ``{"role", "images": [...]}``, where the single
    image entry is produced by ``convert_image_media_to_url`` with
    ``download=True, include_format=False``. Any other item is passed through
    unchanged as ``{"role", "content": item}``.

    Returns:
        One dict per content item; a non-list content yields a one-element list.
    """
    raw_content = message.content
    pieces = raw_content if isinstance(raw_content, list) else [raw_content]

    converted = []
    for piece in pieces:
        if not isinstance(piece, ImageMedia):
            converted.append({"role": message.role, "content": piece})
            continue
        encoded = await convert_image_media_to_url(
            piece, download=True, include_format=False
        )
        converted.append({"role": message.role, "images": [encoded]})
    return converted
 |