Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-04 12:07:34 +00:00
Merge branch 'main' into use-openai-for-anthropic
Commit 1325d4b1e5
4 changed files with 13 additions and 134 deletions
@@ -248,7 +248,7 @@ Available Models:
         api=Api.inference,
         adapter=AdapterSpec(
             adapter_type="groq",
-            pip_packages=["litellm"],
+            pip_packages=["litellm", "openai"],
             module="llama_stack.providers.remote.inference.groq",
             config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
             provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
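The new openai entry in pip_packages is needed because the adapter below now drives Groq's OpenAI-compatible endpoint through the official OpenAI SDK (litellm stays because LiteLLMOpenAIMixin remains a base class). A minimal sketch of that call path against the endpoint that get_base_url() builds in the next file; the base URL default and model id are assumptions, not taken from this diff:

import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    # Assumption: GroqConfig.url defaults to https://api.groq.com, so the
    # OpenAI-compatible base URL is "<url>/openai/v1" (see get_base_url() below).
    client = AsyncOpenAI(base_url="https://api.groq.com/openai/v1", api_key="gsk-your-key")
    resp = await client.chat.completions.create(
        model="llama-3.3-70b-versatile",  # illustrative Groq model id
        messages=[{"role": "user", "content": "Say hello in one word."}],
    )
    print(resp.choices[0].message.content)


asyncio.run(main())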
@@ -4,30 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import AsyncIterator
-from typing import Any
-
-from openai import AsyncOpenAI
-
-from llama_stack.apis.inference import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAIChoiceDelta,
-    OpenAIChunkChoice,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
-    OpenAISystemMessageParam,
-)
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
-from llama_stack.providers.utils.inference.openai_compat import (
-    prepare_openai_completion_params,
-)
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .models import MODEL_ENTRIES
 
 
-class GroqInferenceAdapter(LiteLLMOpenAIMixin):
+class GroqInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
     _config: GroqConfig
 
     def __init__(self, config: GroqConfig):
@@ -40,122 +25,14 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
         )
         self.config = config
 
+    # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
+    get_api_key = LiteLLMOpenAIMixin.get_api_key
+
+    def get_base_url(self) -> str:
+        return f"{self.config.url}/openai/v1"
+
     async def initialize(self):
         await super().initialize()
 
     async def shutdown(self):
         await super().shutdown()
-
-    def _get_openai_client(self) -> AsyncOpenAI:
-        return AsyncOpenAI(
-            base_url=f"{self.config.url}/openai/v1",
-            api_key=self.get_api_key(),
-        )
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        model_obj = await self.model_store.get_model(model)
-
-        # Groq does not support json_schema response format, so we need to convert it to json_object
-        if response_format and response_format.type == "json_schema":
-            response_format.type = "json_object"
-            schema = response_format.json_schema.get("schema", {})
-            response_format.json_schema = None
-            json_instructions = f"\nYour response should be a JSON object that matches the following schema: {schema}"
-            if messages and messages[0].role == "system":
-                messages[0].content = messages[0].content + json_instructions
-            else:
-                messages.insert(0, OpenAISystemMessageParam(content=json_instructions))
-
-        # Groq returns a 400 error if tools are provided but none are called
-        # So, set tool_choice to "required" to attempt to force a call
-        if tools and (not tool_choice or tool_choice == "auto"):
-            tool_choice = "required"
-
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-
-        # Groq does not support streaming requests that set response_format
-        fake_stream = False
-        if stream and response_format:
-            params["stream"] = False
-            fake_stream = True
-
-        response = await self._get_openai_client().chat.completions.create(**params)
-
-        if fake_stream:
-            chunk_choices = []
-            for choice in response.choices:
-                delta = OpenAIChoiceDelta(
-                    content=choice.message.content,
-                    role=choice.message.role,
-                    tool_calls=choice.message.tool_calls,
-                )
-                chunk_choice = OpenAIChunkChoice(
-                    delta=delta,
-                    finish_reason=choice.finish_reason,
-                    index=choice.index,
-                    logprobs=None,
-                )
-                chunk_choices.append(chunk_choice)
-            chunk = OpenAIChatCompletionChunk(
-                id=response.id,
-                choices=chunk_choices,
-                object="chat.completion.chunk",
-                created=response.created,
-                model=response.model,
-            )
-
-            async def _fake_stream_generator():
-                yield chunk
-
-            return _fake_stream_generator()
-        else:
-            return response
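Net effect of the two hunks above: the adapter no longer builds its own AsyncOpenAI client or reimplements openai_chat_completion with Groq-specific workarounds; it only supplies get_api_key (delegated to LiteLLMOpenAIMixin) and get_base_url, and inherits the OpenAI-compatible request path from OpenAIMixin. A minimal sketch of that hook pattern, assuming the mixin builds its client from the two hooks; this is illustrative, not the actual OpenAIMixin implementation:

from openai import AsyncOpenAI


class OpenAIMixinSketch:
    """Illustrative stand-in for OpenAIMixin: subclasses provide two hooks."""

    def get_api_key(self) -> str:
        # Hook: the Groq adapter delegates this to LiteLLMOpenAIMixin, which can
        # read a per-request key from provider data or fall back to the config.
        raise NotImplementedError

    def get_base_url(self) -> str:
        # Hook: the Groq adapter returns f"{self.config.url}/openai/v1".
        raise NotImplementedError

    @property
    def client(self) -> AsyncOpenAI:
        # Built from the hooks on each access, so a per-request API key is not
        # cached across requests (which is what the unit test below checks).
        return AsyncOpenAI(base_url=self.get_base_url(), api_key=self.get_api_key())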
@@ -37,6 +37,9 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
         "remote::sambanova",
         "remote::tgi",
         "remote::vertexai",
+        # {"error":{"message":"Unknown request URL: GET /openai/v1/completions. Please check the URL for typos,
+        # or see the docs at https://console.groq.com/docs/","type":"invalid_request_error","code":"unknown_url"}}
+        "remote::groq",
         "remote::gemini",  # https://generativelanguage.googleapis.com/v1beta/openai/completions -> 404
         "remote::anthropic",  # at least claude-3-{5,7}-{haiku,sonnet}-* / claude-{sonnet,opus}-4-* are not supported
     ):
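This hunk adds Groq to the providers for which the legacy OpenAI /completions tests are skipped: Groq's OpenAI-compatible surface exposes /chat/completions but answers GET /openai/v1/completions with the unknown_url error quoted in the added comment. The helper presumably follows the usual pytest skip pattern; a rough sketch, with the provider lookup elided since it is not shown in this diff:

import pytest


def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
    provider_type = ...  # resolved from client_with_models / model_id in the real helper
    if provider_type in (
        "remote::groq",    # /openai/v1/completions -> unknown_url
        "remote::gemini",  # OpenAI-compat endpoint has no /completions
        "remote::anthropic",
    ):
        pytest.skip(f"{model_id} ({provider_type}) does not support OpenAI /completions")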
@@ -33,8 +33,7 @@ def test_groq_provider_openai_client_caching():
         with request_provider_data_context(
             {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
        ):
-            openai_client = inference_adapter._get_openai_client()
-            assert openai_client.api_key == api_key
+            assert inference_adapter.client.api_key == api_key


 def test_openai_provider_openai_client_caching():
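Since _get_openai_client() was removed from the adapter, the test now asserts on the client property inherited from OpenAIMixin, checking the same behavior: each request's x-llamastack-provider-data header yields a client bound to that request's API key rather than a cached one. A hedged example of what a caller would send to supply a per-request Groq key; the field name groq_api_key is assumed from the provider data validator naming and does not appear in this diff:

import json

# Illustrative per-request headers for a llama-stack server whose Groq provider
# key is supplied by the caller instead of being configured server-side.
headers = {
    "x-llamastack-provider-data": json.dumps({"groq_api_key": "gsk-your-key"}),
}
# These headers accompany a normal inference request; inside the server,
# request_provider_data_context (used by the test above) makes the key visible
# to GroqInferenceAdapter.get_api_key() for that request only.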