feat: simplify remote providers and tool use capability (#4970)
* feat: built-in remote providers go to tokenjs
* fix: error handling
* fix: extend models
* chore: error handling
* chore: update advanced settings of built-in providers
* chore: clean up message creation
* chore: fix import
* fix: engine name
* fix: error handling
parent 7748f0c7e1
commit f3a808cb89
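The change routes the built-in remote providers through token.js instead of the raw OpenAI client plus per-provider request templates, and tool use rides on the same call. A minimal sketch of the new call shape, based on the forked token.js build used in this commit (which accepts apiKey/baseURL in the constructor); the provider name, model id, and key wiring are illustrative placeholders:

import { TokenJS } from 'token.js'

async function askGemini(apiKey: string): Promise<string | null> {
  const tokenJS = new TokenJS({ apiKey })

  // Same call shape the hook uses: a token.js provider name plus an OpenAI-style payload.
  const response = await tokenJS.chat.completions.create({
    stream: false,
    provider: 'gemini',            // one of the builtInEngines listed in the hook below
    model: 'gemini-1.5-flash',     // hypothetical model id
    messages: [{ role: 'user', content: 'Hello' }],
  })
  return response.choices[0]?.message?.content ?? null
}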
@@ -5,7 +5,7 @@
  "url": "https://aistudio.google.com/apikey",
  "api_key": "",
  "metadata": {
    "get_models_url": "https://generativelanguage.googleapis.com/v1beta/models",
    "get_models_url": "https://generativelanguage.googleapis.com/openai/v1beta/models",
    "header_template": "Authorization: Bearer {{api_key}}",
    "transform_req": {
      "chat_completions": {
@@ -107,7 +107,6 @@ mod tests {
    use super::*;
    use std::fs::{self, File};
    use std::io::Write;
    use serde_json::to_string;
    use tauri::test::mock_app;

    #[test]
@@ -1,3 +1,4 @@
import 'openai/shims/web'
import { useCallback, useMemo, useState } from 'react'

import {
@@ -18,11 +19,66 @@ import { useAtom, useAtomValue } from 'jotai'
import { atomWithStorage } from 'jotai/utils'
import useSWR from 'swr'

import { models, TokenJS } from 'token.js'
import { LLMProvider } from 'token.js/dist/chat'

import { getDescriptionByEngine, getTitleByEngine } from '@/utils/modelEngine'

import { extensionManager } from '@/extension/ExtensionManager'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'

export const builtInEngines = [
  'openai',
  'ai21',
  'anthropic',
  'gemini',
  'cohere',
  'bedrock',
  'mistral',
  'groq',
  'perplexity',
  'openrouter',
  'openai-compatible',
]

export const convertBuiltInEngine = (engine?: string): LLMProvider => {
  const engineName = normalizeBuiltInEngineName(engine) ?? ''
  return (
    builtInEngines.includes(engineName) ? engineName : 'openai-compatible'
  ) as LLMProvider
}

export const normalizeBuiltInEngineName = (
  engine?: string
): string | undefined => {
  return engine === ('google_gemini' as InferenceEngine) ? 'gemini' : engine
}

export const extendBuiltInEngineModels = (
  tokenJS: TokenJS,
  provider: LLMProvider,
  model?: string
) => {
  if (provider !== 'openrouter' && provider !== 'openai-compatible' && model) {
    if (
      provider in Object.keys(models) &&
      (models[provider].models as unknown as string[]).includes(model)
    ) {
      return
    }

    try {
      // @ts-expect-error Unknown extendModelList provider type
      tokenJS.extendModelList(provider, model, {
        streaming: true,
        toolCalls: true,
      })
    } catch (error) {
      console.error('Failed to extend model list:', error)
    }
  }
}

export const releasedEnginesCacheAtom = atomWithStorage<{
  data: EngineReleased[]
  timestamp: number
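A minimal sketch of how the helpers above fit together, assuming they are imported from this hook module; the API key and model id are placeholders:

import { TokenJS } from 'token.js'

// Placeholder key; the real one comes from the engine settings.
const tokenJS = new TokenJS({ apiKey: '<provider-api-key>' })

// 'google_gemini' (the Jan engine name) normalizes to token.js's 'gemini' provider;
// anything outside builtInEngines falls back to 'openai-compatible'.
const provider = convertBuiltInEngine('google_gemini') // -> 'gemini'

// Registers a model id that token.js does not ship with, flagging streaming and
// tool-call support so tool use still works for models outside the bundled catalog.
extendBuiltInEngineModels(tokenJS, provider, 'gemini-2.0-flash') // hypothetical model id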
@@ -48,7 +48,7 @@ export default function useFactoryReset() {

    // 2: Delete the old jan data folder
    setFactoryResetState(FactoryResetState.DeletingData)
    await fs.rm({ args: [janDataFolderPath] })
    await fs.rm(janDataFolderPath)

    // 3: Set the default jan data folder
    if (!keepCurrentFolder) {
@@ -61,6 +61,8 @@ export default function useFactoryReset() {
      await window.core?.api?.updateAppConfiguration({ configuration })
    }

    await window.core?.api?.installExtensions()

    // Perform factory reset
    // await window.core?.api?.factoryReset()
@@ -1,3 +1,5 @@
import 'openai/shims/web'

import { useEffect, useRef } from 'react'

import {
@@ -18,16 +20,19 @@ import {
} from '@janhq/core'
import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { OpenAI } from 'openai'

import {
  ChatCompletionMessageParam,
  ChatCompletionRole as OpenAIChatCompletionRole,
  ChatCompletionTool,
  ChatCompletionMessageToolCall,
} from 'openai/resources/chat'

import { Stream } from 'openai/streaming'
import {
  CompletionResponse,
  StreamCompletionResponse,
  TokenJS,
  models,
} from 'token.js'
import { ulid } from 'ulidx'

import { modelDropdownStateAtom } from '@/containers/ModelDropdown'
@@ -38,12 +43,23 @@ import {
} from '@/containers/Providers/Jotai'

import { compressImage, getBase64 } from '@/utils/base64'
import {
  createMessage,
  createMessageContent,
  emptyMessageContent,
} from '@/utils/createMessage'
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'

import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'

import { useActiveModel } from './useActiveModel'

import {
  convertBuiltInEngine,
  extendBuiltInEngineModels,
  useGetEngines,
} from './useEngineManagement'

import { extensionManager } from '@/extension/ExtensionManager'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import {
@@ -100,6 +116,8 @@ export default function useSendChatMessage(

  const selectedModelRef = useRef<Model | undefined>()

  const { engines } = useGetEngines()

  useEffect(() => {
    modelRef.current = activeModel
  }, [activeModel])
@@ -167,6 +185,7 @@ export default function useSendChatMessage(
    setCurrentPrompt('')
    setEditPrompt('')

    try {
      let base64Blob = fileUpload ? await getBase64(fileUpload.file) : undefined

      if (base64Blob && fileUpload?.type === 'image') {
@@ -265,11 +284,19 @@ export default function useSendChatMessage(

      if (requestBuilder.tools && requestBuilder.tools.length) {
        let isDone = false
        const openai = new OpenAI({
          apiKey: await window.core.api.appToken(),
          baseURL: `${API_BASE_URL}/v1`,
          dangerouslyAllowBrowser: true,

        const engine =
          engines?.[requestBuilder.model.engine as InferenceEngine]?.[0]
        const apiKey = engine?.api_key
        const provider = convertBuiltInEngine(engine?.engine)

        const tokenJS = new TokenJS({
          apiKey: apiKey ?? (await window.core.api.appToken()),
          baseURL: apiKey ? undefined : `${API_BASE_URL}/v1`,
        })

        extendBuiltInEngineModels(tokenJS, provider, modelId)

        let parentMessageId: string | undefined
        while (!isDone) {
          let messageId = ulid()
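The new client prefers the provider's own API key and only falls back to the local gateway (app token plus `${API_BASE_URL}/v1`) when no key is configured. A hedged sketch of that selection, with the token and base URL passed in as plain parameters instead of the app globals:

import { TokenJS } from 'token.js'

// engineApiKey comes from the engine settings; appToken and localBaseUrl stand in for
// window.core.api.appToken() and `${API_BASE_URL}/v1`, which are app globals in the hook.
function makeTokenJSClient(
  engineApiKey: string | undefined,
  appToken: string,
  localBaseUrl: string
): TokenJS {
  return new TokenJS({
    apiKey: engineApiKey ?? appToken,
    // With a provider key, token.js talks to the provider directly; without one,
    // requests are proxied through the local API server.
    baseURL: engineApiKey ? undefined : localBaseUrl,
  })
}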
@@ -278,51 +305,57 @@ export default function useSendChatMessage(
            messageId = parentMessageId
          }
          const data = requestBuilder.build()
          const message: ThreadMessage = {
          const message: ThreadMessage = createMessage({
            id: messageId,
            object: 'message',
            thread_id: activeThread.id,
            assistant_id: activeAssistant.assistant_id,
            role: ChatCompletionRole.Assistant,
            content: [],
            metadata: {
              ...(messageId !== parentMessageId
                ? { parent_id: parentMessageId }
                : {}),
            },
            status: MessageStatus.Pending,
            created_at: Date.now() / 1000,
            completed_at: Date.now() / 1000,
          }
          })
          events.emit(MessageEvent.OnMessageResponse, message)
          const response = await openai.chat.completions.create({
          // Variables to track and accumulate streaming content

          if (
            data.model?.parameters?.stream &&
            data.model?.engine !== InferenceEngine.cortex &&
            data.model?.engine !== InferenceEngine.cortex_llamacpp
          ) {
            const response = await tokenJS.chat.completions.create({
              stream: true,
              provider,
              messages: requestBuilder.messages as ChatCompletionMessageParam[],
              model: data.model?.id ?? '',
              tools: data.tools as ChatCompletionTool[],
              tool_choice: 'auto',
            })

            if (!message.content.length) {
              message.content = emptyMessageContent
            }

            isDone = await processStreamingResponse(
              response,
              requestBuilder,
              message
            )
          } else {
            const response = await tokenJS.chat.completions.create({
              stream: false,
              provider,
              messages: requestBuilder.messages as ChatCompletionMessageParam[],
              model: data.model?.id ?? '',
              tools: data.tools as ChatCompletionTool[],
              stream: data.model?.parameters?.stream ?? false,
              tool_choice: 'auto',
            })
            // Variables to track and accumulate streaming content
            if (!message.content.length) {
              message.content = [
                {
                  type: ContentType.Text,
                  text: {
                    value: '',
                    annotations: [],
                  },
                },
              ]
              message.content = emptyMessageContent
            }
            if (data.model?.parameters?.stream)
              isDone = await processStreamingResponse(
                response as Stream<OpenAI.Chat.Completions.ChatCompletionChunk>,
                requestBuilder,
                message
              )
            else {
              isDone = await processNonStreamingResponse(
                response as OpenAI.Chat.Completions.ChatCompletion,
                response,
                requestBuilder,
                message
              )
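For the streaming branch, token.js returns a chunk stream whose shape mirrors OpenAI's, and the hook accumulates `delta.content` from each chunk. A hedged, standalone sketch of that loop, assuming the stream is async-iterable as in the OpenAI SDK; provider, model id, and prompt are placeholders:

import { TokenJS } from 'token.js'

async function streamAssistantReply(tokenJS: TokenJS): Promise<string> {
  const stream = await tokenJS.chat.completions.create({
    stream: true,
    provider: 'openai',            // any entry from builtInEngines
    model: 'gpt-4o-mini',          // hypothetical model id
    messages: [{ role: 'user', content: 'Say hello in one word.' }],
  })

  // Accumulate the delta content from each chunk, as processStreamingResponse does.
  let accumulated = ''
  for await (const chunk of stream) {
    accumulated += chunk.choices[0]?.delta?.content ?? ''
  }
  return accumulated
}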
@@ -336,6 +369,23 @@ export default function useSendChatMessage(
        .get(InferenceEngine.cortex)
        ?.inference(requestBuilder.build())
      }
    } catch (error) {
      setIsGeneratingResponse(false)
      updateThreadWaiting(activeThread.id, false)
      const errorMessage: ThreadMessage = createMessage({
        thread_id: activeThread.id,
        assistant_id: activeAssistant.assistant_id,
        content: createMessageContent(
          typeof error === 'object' && error && 'message' in error
            ? (error as { message: string }).message
            : JSON.stringify(error)
        ),
      })
      events.emit(MessageEvent.OnMessageResponse, errorMessage)

      errorMessage.status = MessageStatus.Error
      events.emit(MessageEvent.OnMessageUpdate, errorMessage)
    }

    // Reset states
    setReloadModel(false)
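The catch block folds any thrown value into a plain message string before turning it into an error `ThreadMessage`. The same normalization, pulled out as a small helper for illustration:

// Mirrors the inline check above: prefer a thrown Error-like object's message,
// otherwise fall back to serializing whatever was thrown.
function toErrorText(error: unknown): string {
  return typeof error === 'object' && error !== null && 'message' in error
    ? (error as { message: string }).message
    : JSON.stringify(error)
}

// e.g. toErrorText(new Error('rate limited')) === 'rate limited'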
@@ -343,7 +393,7 @@ export default function useSendChatMessage(
  }

  const processNonStreamingResponse = async (
    response: OpenAI.Chat.Completions.ChatCompletion,
    response: CompletionResponse,
    requestBuilder: MessageRequestBuilder,
    message: ThreadMessage
  ): Promise<boolean> => {
@@ -351,15 +401,7 @@ export default function useSendChatMessage(
    const toolCalls: ChatCompletionMessageToolCall[] =
      response.choices[0]?.message?.tool_calls ?? []
    const content = response.choices[0].message?.content
    message.content = [
      {
        type: ContentType.Text,
        text: {
          value: content ?? '',
          annotations: [],
        },
      },
    ]
    message.content = createMessageContent(content ?? '')
    events.emit(MessageEvent.OnMessageUpdate, message)
    await postMessageProcessing(
      toolCalls ?? [],
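In the non-streaming path the hook reads both the assistant text and any tool calls off the first choice of the token.js `CompletionResponse`. A sketch of that extraction as a standalone helper:

import type { ChatCompletionMessageToolCall } from 'openai/resources/chat'
import type { CompletionResponse } from 'token.js'

// Pulls out the two pieces processNonStreamingResponse cares about: the text content
// (turned into ThreadContent via createMessageContent) and the tool calls that
// postMessageProcessing will execute.
function readCompletion(response: CompletionResponse): {
  content: string
  toolCalls: ChatCompletionMessageToolCall[]
} {
  const choice = response.choices[0]
  return {
    content: choice?.message?.content ?? '',
    toolCalls: choice?.message?.tool_calls ?? [],
  }
}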
@@ -371,7 +413,7 @@ export default function useSendChatMessage(
  }

  const processStreamingResponse = async (
    response: Stream<OpenAI.Chat.Completions.ChatCompletionChunk>,
    response: StreamCompletionResponse,
    requestBuilder: MessageRequestBuilder,
    message: ThreadMessage
  ): Promise<boolean> => {
@@ -428,15 +470,7 @@ export default function useSendChatMessage(
        const content = chunk.choices[0].delta.content
        accumulatedContent += content

        message.content = [
          {
            type: ContentType.Text,
            text: {
              value: accumulatedContent,
              annotations: [],
            },
          },
        ]
        message.content = createMessageContent(accumulatedContent)
        events.emit(MessageEvent.OnMessageUpdate, message)
      }
    }
@@ -65,6 +65,7 @@
    "swr": "^2.2.5",
    "tailwind-merge": "^2.0.0",
    "tailwindcss": "3.4.17",
    "token.js": "npm:token.js-fork@0.7.2",
    "ulidx": "^2.3.0",
    "use-debounce": "^10.0.0",
    "uuid": "^9.0.1",
@@ -32,6 +32,8 @@ import { twMerge } from 'tailwind-merge'
import Spinner from '@/containers/Loader/Spinner'

import {
  builtInEngines,
  normalizeBuiltInEngineName,
  updateEngine,
  useGetEngines,
  useRefreshModelList,
@@ -366,6 +368,10 @@ const RemoteEngineSettings = ({
              </div>
            </div>
          </div>
          {!builtInEngines.includes(
            normalizeBuiltInEngineName(engineName) ?? ''
          ) && (
            <>
              <div className="block w-full px-4">
                <div className="mb-3 mt-4 border-b border-[hsla(var(--app-border))] pb-4">
                  <div className="flex w-full flex-col items-start justify-between sm:flex-row">
@@ -376,8 +382,8 @@ const RemoteEngineSettings = ({
                        Request Headers Template
                      </h6>
                      <p className="mt-1 text-[hsla(var(--text-secondary))]">
                        HTTP headers template required for API authentication
                        and version specification.
                        HTTP headers template required for API
                        authentication and version specification.
                      </p>
                    </div>
                    <div className="w-full">
@@ -407,8 +413,8 @@ const RemoteEngineSettings = ({
                        Request Format Conversion
                      </h6>
                      <p className="mt-1 text-[hsla(var(--text-secondary))]">
                        Template to transform OpenAI-compatible requests into
                        provider-specific format.
                        Template to transform OpenAI-compatible requests
                        into provider-specific format.
                      </p>
                    </div>
                    <div className="w-full">
@@ -465,6 +471,8 @@ const RemoteEngineSettings = ({
                </div>
              </div>
            </div>
            </>
          )}
        </div>
      )}
    </ScrollArea>
@@ -22,6 +22,8 @@ export const Routes = [
  'saveMcpConfigs',
  'getMcpConfigs',
  'restartMcpServers',
  'relaunch',
  'installExtensions',
].map((r) => ({
  path: `app`,
  route: r,
web/utils/createMessage.ts (new file, 45 lines)
@@ -0,0 +1,45 @@
import {
  ChatCompletionRole,
  ContentType,
  MessageStatus,
  ThreadContent,
  ThreadMessage,
} from '@janhq/core'
import { ulid } from 'ulidx'

export const emptyMessageContent: ThreadContent[] = [
  {
    type: ContentType.Text,
    text: {
      value: '',
      annotations: [],
    },
  },
]

export const createMessageContent = (text: string): ThreadContent[] => {
  return [
    {
      type: ContentType.Text,
      text: {
        value: text,
        annotations: [],
      },
    },
  ]
}

export const createMessage = (opts: Partial<ThreadMessage>): ThreadMessage => {
  return {
    id: opts.id ?? ulid(),
    object: 'message',
    thread_id: opts.thread_id ?? '',
    assistant_id: opts.assistant_id ?? '',
    role: opts.role ?? ChatCompletionRole.Assistant,
    content: opts.content ?? [],
    metadata: opts.metadata ?? {},
    status: opts.status ?? MessageStatus.Pending,
    created_at: opts.created_at ?? Date.now() / 1000,
    completed_at: opts.completed_at ?? Date.now() / 1000,
  }
}
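These helpers replace the hand-built message literals in useSendChatMessage; every omitted field falls back to a sensible default. A short usage sketch (the thread and assistant ids are placeholders):

import { createMessage, createMessageContent } from '@/utils/createMessage'

// A pending assistant message for a thread; id, object, role, status, and timestamps
// all default (ulid id, Assistant role, Pending status, "now" timestamps).
const pending = createMessage({
  thread_id: 'thread_123',   // placeholder id
  assistant_id: 'jan',       // placeholder id
})

// Later, replace the (empty) content with the generated text.
pending.content = createMessageContent('Hello from the assistant.')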