feat: simplify remote providers and tool use capability (#4970)

* feat: built-in remote providers go to tokenjs

* fix: error handling

* fix: extend models

* chore: error handling

* chore: update advanced settings of built-in providers

* chore: clean up message creation

* chore: fix import

* fix: engine name

* fix: error handling
Author: Louis
Date: 2025-05-09 16:25:36 +07:00
Parent: 7748f0c7e1
Commit: f3a808cb89
9 changed files with 420 additions and 273 deletions

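The core of the change: built-in remote providers now go through a token.js client instead of the OpenAI SDK. A minimal sketch of the client selection, condensed from the useSendChatMessage diff below (the `createRemoteClient` helper and its parameters are illustrative, not part of the PR):

```ts
import { TokenJS } from 'token.js'

// Illustrative helper: with a provider API key the request goes straight to the
// provider; without one it falls back to the local Jan server, authenticated
// with the app token.
export const createRemoteClient = (
  apiKey: string | undefined,
  appToken: string,
  apiBaseUrl: string
): TokenJS =>
  new TokenJS({
    apiKey: apiKey ?? appToken,
    baseURL: apiKey ? undefined : `${apiBaseUrl}/v1`,
  })
```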
View File

@@ -5,7 +5,7 @@
"url": "https://aistudio.google.com/apikey",
"api_key": "",
"metadata": {
"get_models_url": "https://generativelanguage.googleapis.com/v1beta/models",
"get_models_url": "https://generativelanguage.googleapis.com/openai/v1beta/models",
"header_template": "Authorization: Bearer {{api_key}}",
"transform_req": {
"chat_completions": {

View File

@@ -107,7 +107,6 @@ mod tests {
use super::*;
use std::fs::{self, File};
use std::io::Write;
use serde_json::to_string;
use tauri::test::mock_app;
#[test]

View File

@@ -1,3 +1,4 @@
import 'openai/shims/web'
import { useCallback, useMemo, useState } from 'react'
import {
@@ -18,11 +19,66 @@ import { useAtom, useAtomValue } from 'jotai'
import { atomWithStorage } from 'jotai/utils'
import useSWR from 'swr'
import { models, TokenJS } from 'token.js'
import { LLMProvider } from 'token.js/dist/chat'
import { getDescriptionByEngine, getTitleByEngine } from '@/utils/modelEngine'
import { extensionManager } from '@/extension/ExtensionManager'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
export const builtInEngines = [
'openai',
'ai21',
'anthropic',
'gemini',
'cohere',
'bedrock',
'mistral',
'groq',
'perplexity',
'openrouter',
'openai-compatible',
]
export const convertBuiltInEngine = (engine?: string): LLMProvider => {
const engineName = normalizeBuiltInEngineName(engine) ?? ''
return (
builtInEngines.includes(engineName) ? engineName : 'openai-compatible'
) as LLMProvider
}
export const normalizeBuiltInEngineName = (
engine?: string
): string | undefined => {
return engine === ('google_gemini' as InferenceEngine) ? 'gemini' : engine
}
export const extendBuiltInEngineModels = (
tokenJS: TokenJS,
provider: LLMProvider,
model?: string
) => {
if (provider !== 'openrouter' && provider !== 'openai-compatible' && model) {
if (
provider in Object.keys(models) &&
(models[provider].models as unknown as string[]).includes(model)
) {
return
}
try {
// @ts-expect-error Unknown extendModelList provider type
tokenJS.extendModelList(provider, model, {
streaming: true,
toolCalls: true,
})
} catch (error) {
console.error('Failed to extend model list:', error)
}
}
}
export const releasedEnginesCacheAtom = atomWithStorage<{
data: EngineReleased[]
timestamp: number

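A hedged usage sketch of the helpers added above; the `@/hooks/useEngineManagement` import path, the API key placeholder, and the model name are assumptions for illustration:

```ts
import { TokenJS } from 'token.js'
import {
  convertBuiltInEngine,
  extendBuiltInEngineModels,
} from '@/hooks/useEngineManagement' // assumed path

// Jan engine names normalize onto token.js providers; anything unknown falls
// back to 'openai-compatible'.
const provider = convertBuiltInEngine('google_gemini') // -> 'gemini'

const tokenJS = new TokenJS({ apiKey: '<provider-api-key>' })

// Register a model token.js ships no metadata for, flagging it as streaming-
// and tool-call-capable so tool use keeps working.
extendBuiltInEngineModels(tokenJS, provider, 'gemini-2.0-flash')
```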
View File

@@ -48,7 +48,7 @@ export default function useFactoryReset() {
// 2: Delete the old jan data folder
setFactoryResetState(FactoryResetState.DeletingData)
await fs.rm({ args: [janDataFolderPath] })
await fs.rm(janDataFolderPath)
// 3: Set the default jan data folder
if (!keepCurrentFolder) {
@@ -61,6 +61,8 @@ export default function useFactoryReset() {
await window.core?.api?.updateAppConfiguration({ configuration })
}
await window.core?.api?.installExtensions()
// Perform factory reset
// await window.core?.api?.factoryReset()

View File

@@ -1,3 +1,5 @@
import 'openai/shims/web'
import { useEffect, useRef } from 'react'
import {
@@ -18,16 +20,19 @@ import {
} from '@janhq/core'
import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { OpenAI } from 'openai'
import {
ChatCompletionMessageParam,
ChatCompletionRole as OpenAIChatCompletionRole,
ChatCompletionTool,
ChatCompletionMessageToolCall,
} from 'openai/resources/chat'
import { Stream } from 'openai/streaming'
import {
CompletionResponse,
StreamCompletionResponse,
TokenJS,
models,
} from 'token.js'
import { ulid } from 'ulidx'
import { modelDropdownStateAtom } from '@/containers/ModelDropdown'
@@ -38,12 +43,23 @@ import {
} from '@/containers/Providers/Jotai'
import { compressImage, getBase64 } from '@/utils/base64'
import {
createMessage,
createMessageContent,
emptyMessageContent,
} from '@/utils/createMessage'
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
import { useActiveModel } from './useActiveModel'
import {
convertBuiltInEngine,
extendBuiltInEngineModels,
useGetEngines,
} from './useEngineManagement'
import { extensionManager } from '@/extension/ExtensionManager'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import {
@@ -100,6 +116,8 @@ export default function useSendChatMessage(
const selectedModelRef = useRef<Model | undefined>()
const { engines } = useGetEngines()
useEffect(() => {
modelRef.current = activeModel
}, [activeModel])
@@ -167,6 +185,7 @@ export default function useSendChatMessage(
setCurrentPrompt('')
setEditPrompt('')
try {
let base64Blob = fileUpload ? await getBase64(fileUpload.file) : undefined
if (base64Blob && fileUpload?.type === 'image') {
@@ -265,11 +284,19 @@ export default function useSendChatMessage(
if (requestBuilder.tools && requestBuilder.tools.length) {
let isDone = false
const openai = new OpenAI({
apiKey: await window.core.api.appToken(),
baseURL: `${API_BASE_URL}/v1`,
dangerouslyAllowBrowser: true,
const engine =
engines?.[requestBuilder.model.engine as InferenceEngine]?.[0]
const apiKey = engine?.api_key
const provider = convertBuiltInEngine(engine?.engine)
const tokenJS = new TokenJS({
apiKey: apiKey ?? (await window.core.api.appToken()),
baseURL: apiKey ? undefined : `${API_BASE_URL}/v1`,
})
extendBuiltInEngineModels(tokenJS, provider, modelId)
let parentMessageId: string | undefined
while (!isDone) {
let messageId = ulid()
@@ -278,51 +305,57 @@ export default function useSendChatMessage(
messageId = parentMessageId
}
const data = requestBuilder.build()
const message: ThreadMessage = {
const message: ThreadMessage = createMessage({
id: messageId,
object: 'message',
thread_id: activeThread.id,
assistant_id: activeAssistant.assistant_id,
role: ChatCompletionRole.Assistant,
content: [],
metadata: {
...(messageId !== parentMessageId
? { parent_id: parentMessageId }
: {}),
},
status: MessageStatus.Pending,
created_at: Date.now() / 1000,
completed_at: Date.now() / 1000,
}
})
events.emit(MessageEvent.OnMessageResponse, message)
const response = await openai.chat.completions.create({
// Variables to track and accumulate streaming content
if (
data.model?.parameters?.stream &&
data.model?.engine !== InferenceEngine.cortex &&
data.model?.engine !== InferenceEngine.cortex_llamacpp
) {
const response = await tokenJS.chat.completions.create({
stream: true,
provider,
messages: requestBuilder.messages as ChatCompletionMessageParam[],
model: data.model?.id ?? '',
tools: data.tools as ChatCompletionTool[],
tool_choice: 'auto',
})
if (!message.content.length) {
message.content = emptyMessageContent
}
isDone = await processStreamingResponse(
response,
requestBuilder,
message
)
} else {
const response = await tokenJS.chat.completions.create({
stream: false,
provider,
messages: requestBuilder.messages as ChatCompletionMessageParam[],
model: data.model?.id ?? '',
tools: data.tools as ChatCompletionTool[],
stream: data.model?.parameters?.stream ?? false,
tool_choice: 'auto',
})
// Variables to track and accumulate streaming content
if (!message.content.length) {
message.content = [
{
type: ContentType.Text,
text: {
value: '',
annotations: [],
},
},
]
message.content = emptyMessageContent
}
if (data.model?.parameters?.stream)
isDone = await processStreamingResponse(
response as Stream<OpenAI.Chat.Completions.ChatCompletionChunk>,
requestBuilder,
message
)
else {
isDone = await processNonStreamingResponse(
response as OpenAI.Chat.Completions.ChatCompletion,
response,
requestBuilder,
message
)
@@ -336,6 +369,23 @@ export default function useSendChatMessage(
.get(InferenceEngine.cortex)
?.inference(requestBuilder.build())
}
} catch (error) {
setIsGeneratingResponse(false)
updateThreadWaiting(activeThread.id, false)
const errorMessage: ThreadMessage = createMessage({
thread_id: activeThread.id,
assistant_id: activeAssistant.assistant_id,
content: createMessageContent(
typeof error === 'object' && error && 'message' in error
? (error as { message: string }).message
: JSON.stringify(error)
),
})
events.emit(MessageEvent.OnMessageResponse, errorMessage)
errorMessage.status = MessageStatus.Error
events.emit(MessageEvent.OnMessageUpdate, errorMessage)
}
// Reset states
setReloadModel(false)
@@ -343,7 +393,7 @@
}
const processNonStreamingResponse = async (
response: OpenAI.Chat.Completions.ChatCompletion,
response: CompletionResponse,
requestBuilder: MessageRequestBuilder,
message: ThreadMessage
): Promise<boolean> => {
@@ -351,15 +401,7 @@
const toolCalls: ChatCompletionMessageToolCall[] =
response.choices[0]?.message?.tool_calls ?? []
const content = response.choices[0].message?.content
message.content = [
{
type: ContentType.Text,
text: {
value: content ?? '',
annotations: [],
},
},
]
message.content = createMessageContent(content ?? '')
events.emit(MessageEvent.OnMessageUpdate, message)
await postMessageProcessing(
toolCalls ?? [],
@@ -371,7 +413,7 @@
}
const processStreamingResponse = async (
response: Stream<OpenAI.Chat.Completions.ChatCompletionChunk>,
response: StreamCompletionResponse,
requestBuilder: MessageRequestBuilder,
message: ThreadMessage
): Promise<boolean> => {
@@ -428,15 +470,7 @@
const content = chunk.choices[0].delta.content
accumulatedContent += content
message.content = [
{
type: ContentType.Text,
text: {
value: accumulatedContent,
annotations: [],
},
},
]
message.content = createMessageContent(accumulatedContent)
events.emit(MessageEvent.OnMessageUpdate, message)
}
}

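The two "fix: error handling" items boil down to surfacing failures as an Error-status ThreadMessage instead of leaving the thread waiting. A minimal sketch of that conversion using the helpers from this diff (the `toErrorMessage` wrapper itself is illustrative):

```ts
import { MessageStatus, ThreadMessage } from '@janhq/core'
import { createMessage, createMessageContent } from '@/utils/createMessage'

// Illustrative wrapper: turn any thrown value into an Error-status thread
// message the UI can render in place of the assistant reply.
export const toErrorMessage = (
  error: unknown,
  threadId: string,
  assistantId: string
): ThreadMessage => {
  const text =
    typeof error === 'object' && error && 'message' in error
      ? (error as { message: string }).message
      : JSON.stringify(error)
  const message = createMessage({
    thread_id: threadId,
    assistant_id: assistantId,
    content: createMessageContent(text),
  })
  message.status = MessageStatus.Error
  return message
}
```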
View File

@@ -65,6 +65,7 @@
"swr": "^2.2.5",
"tailwind-merge": "^2.0.0",
"tailwindcss": "3.4.17",
"token.js": "npm:token.js-fork@0.7.2",
"ulidx": "^2.3.0",
"use-debounce": "^10.0.0",
"uuid": "^9.0.1",

View File

@@ -32,6 +32,8 @@ import { twMerge } from 'tailwind-merge'
import Spinner from '@/containers/Loader/Spinner'
import {
builtInEngines,
normalizeBuiltInEngineName,
updateEngine,
useGetEngines,
useRefreshModelList,
@@ -366,6 +368,10 @@ const RemoteEngineSettings = ({
</div>
</div>
</div>
{!builtInEngines.includes(
normalizeBuiltInEngineName(engineName) ?? ''
) && (
<>
<div className="block w-full px-4">
<div className="mb-3 mt-4 border-b border-[hsla(var(--app-border))] pb-4">
<div className="flex w-full flex-col items-start justify-between sm:flex-row">
@@ -376,8 +382,8 @@
Request Headers Template
</h6>
<p className="mt-1 text-[hsla(var(--text-secondary))]">
HTTP headers template required for API authentication
and version specification.
HTTP headers template required for API
authentication and version specification.
</p>
</div>
<div className="w-full">
@@ -407,8 +413,8 @@
Request Format Conversion
</h6>
<p className="mt-1 text-[hsla(var(--text-secondary))]">
Template to transform OpenAI-compatible requests into
provider-specific format.
Template to transform OpenAI-compatible requests
into provider-specific format.
</p>
</div>
<div className="w-full">
@@ -465,6 +471,8 @@ const RemoteEngineSettings = ({
</div>
</div>
</div>
</>
)}
</div>
)}
</ScrollArea>

View File

@@ -22,6 +22,8 @@ export const Routes = [
'saveMcpConfigs',
'getMcpConfigs',
'restartMcpServers',
'relaunch',
'installExtensions',
].map((r) => ({
path: `app`,
route: r,

View File

@@ -0,0 +1,45 @@
import {
ChatCompletionRole,
ContentType,
MessageStatus,
ThreadContent,
ThreadMessage,
} from '@janhq/core'
import { ulid } from 'ulidx'
export const emptyMessageContent: ThreadContent[] = [
{
type: ContentType.Text,
text: {
value: '',
annotations: [],
},
},
]
export const createMessageContent = (text: string): ThreadContent[] => {
return [
{
type: ContentType.Text,
text: {
value: text,
annotations: [],
},
},
]
}
export const createMessage = (opts: Partial<ThreadMessage>): ThreadMessage => {
return {
id: opts.id ?? ulid(),
object: 'message',
thread_id: opts.thread_id ?? '',
assistant_id: opts.assistant_id ?? '',
role: opts.role ?? ChatCompletionRole.Assistant,
content: opts.content ?? [],
metadata: opts.metadata ?? {},
status: opts.status ?? MessageStatus.Pending,
created_at: opts.created_at ?? Date.now() / 1000,
completed_at: opts.completed_at ?? Date.now() / 1000,
}
}
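A short usage sketch of the new helper (thread and assistant IDs are placeholders): call sites pass only the fields that differ from the defaults.

```ts
import { ChatCompletionRole } from '@janhq/core'
import {
  createMessage,
  createMessageContent,
  emptyMessageContent,
} from '@/utils/createMessage'

// Pending assistant placeholder: id, role, status, and timestamps all come
// from the defaults.
const placeholder = createMessage({
  thread_id: 'thread_01',
  assistant_id: 'jan',
  content: emptyMessageContent,
})

// User message with text content, overriding only the role.
const question = createMessage({
  thread_id: 'thread_01',
  assistant_id: 'jan',
  role: ChatCompletionRole.User,
  content: createMessageContent('What changed in this release?'),
})
```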