Merge branch 'main' into add-nvidia-inference-adapter

Matthew Farrellee 2024-11-20 09:37:48 -05:00
commit 8a35dc8b0e
28 changed files with 429 additions and 478 deletions


@@ -50,11 +50,11 @@ MODEL_ALIASES = [
     ),
     build_model_alias(
         "fireworks/llama-v3p2-1b-instruct",
-        CoreModelId.llama3_2_3b_instruct.value,
+        CoreModelId.llama3_2_1b_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-v3p2-3b-instruct",
-        CoreModelId.llama3_2_11b_vision_instruct.value,
+        CoreModelId.llama3_2_3b_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-v3p2-11b-vision-instruct",
@@ -214,10 +214,10 @@ class FireworksInferenceAdapter(
         async def _to_async_generator():
             if "messages" in params:
-                stream = await self._get_client().chat.completions.acreate(**params)
+                stream = self._get_client().chat.completions.acreate(**params)
             else:
-                stream = self._get_client().completion.create(**params)
-            for chunk in stream:
+                stream = self._get_client().completion.acreate(**params)
+            async for chunk in stream:
                 yield chunk

         stream = _to_async_generator()
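
A self-contained sketch of the corrected streaming pattern follows; the Fireworks client is stubbed out here, and only the async-generator shape mirrors the adapter change:

import asyncio


# Stub standing in for the Fireworks client: acreate(**params) returns an
# async iterator of chunks, so it is consumed with `async for` rather than
# awaited into a plain value first.
async def _fake_acreate(**params):
    for i in range(3):
        yield {"chunk": i}


async def _to_async_generator(params):
    stream = _fake_acreate(**params)  # no `await`: this is an async iterator
    async for chunk in stream:
        yield chunk


async def main():
    async for chunk in _to_async_generator({"model": "demo"}):
        print(chunk)


asyncio.run(main())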


@@ -264,6 +264,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
 class TGIAdapter(_HfAdapter):
     async def initialize(self, config: TGIImplConfig) -> None:
+        print(f"Initializing TGI client with url={config.url}")
         self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]


@@ -53,6 +53,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self.client = None

     async def initialize(self) -> None:
+        print(f"Initializing VLLM client with base_url={self.config.url}")
         self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)

     async def shutdown(self) -> None:
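
And a minimal sketch of pointing the OpenAI client at a vLLM server's OpenAI-compatible endpoint (URL, API key, and model name are placeholders):

from openai import OpenAI

# vLLM usually serves the OpenAI-compatible API under /v1; the key is often
# unused unless the server was started with --api-key.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

# response = client.chat.completions.create(
#     model="meta-llama/Llama-3.2-3B-Instruct",  # placeholder model name
#     messages=[{"role": "user", "content": "Hello"}],
# )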