llama-stack/llama_stack/providers/remote/inference/databricks/databricks.py
Ashwin Bharambe 8de8eb03c8
Update the "InterleavedTextMedia" type (#635)
## What does this PR do?

This is a long-pending change and particularly important to get done
now.

Specifically:
- we cannot "localize" (aka download) any URLs from media attachments
anywhere near our modeling code. it must be done within llama-stack.
- `PIL.Image` is infesting all our APIs via `ImageMedia ->
InterleavedTextMedia` and that cannot be right at all. Anything in the
API surface must be "naturally serializable". We need a standard `{
type: "image", image_url: "<...>" }` which is more extensible
- `UserMessage`, `SystemMessage`, etc. are moved completely to
llama-stack from the llama-models repository.

See https://github.com/meta-llama/llama-models/pull/244 for the
corresponding PR in llama-models.

## Test Plan

```bash
cd llama_stack/providers/tests

pytest -s -v -k "fireworks or ollama or together" inference/test_vision_inference.py
pytest -s -v -k "(fireworks or ollama or together) and llama_3b" inference/test_text_inference.py
pytest -s -v -k chroma memory/test_memory.py \
  --env EMBEDDING_DIMENSION=384 --env CHROMA_DB_PATH=/tmp/foobar

pytest -s -v -k fireworks agents/test_agents.py  \
   --safety-shield=meta-llama/Llama-Guard-3-8B \
   --inference-model=meta-llama/Llama-3.1-8B-Instruct
```

Updated the client sdk (see PR ...), installed the SDK in the same
environment and then ran the SDK tests:

```bash
cd tests/client-sdk
LLAMA_STACK_CONFIG=together pytest -s -v agents/test_agents.py
LLAMA_STACK_CONFIG=ollama pytest -s -v memory/test_memory.py

# this one needed a bit of hacking in the run.yaml to ensure I could register the vision model correctly
INFERENCE_MODEL=llama3.2-vision:latest LLAMA_STACK_CONFIG=ollama pytest -s -v inference/test_inference.py
```
2024-12-17 11:18:31 -08:00

140 lines
4.4 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from .config import DatabricksImplConfig
model_aliases = [
build_model_alias(
"databricks-meta-llama-3-1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"databricks-meta-llama-3-1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
]
class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(
self,
model_aliases=model_aliases,
)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def completion(
self,
model: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
if stream:
return self._stream_chat_completion(request, client)
else:
return await self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> AsyncGenerator:
params = self._get_params(request)
async def _to_async_generator():
s = client.completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": request.model,
"prompt": chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
async def embeddings(
self,
model: str,
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
raise NotImplementedError()