feat: simplify remote providers and tool use capability (#4970)

* feat: built-in remote providers go to tokenjs

* fix: error handling

* fix: extend models

* chore: error handling

* chore: update advanced settings of built-in providers

* chore: clean up message creation

* chore: fix import

* fix: engine name

* fix: error handling
Louis 2025-05-09 16:25:36 +07:00
parent 7748f0c7e1
commit f3a808cb89
GPG Key ID: 44FA9F4D33C37DE2
9 changed files with 420 additions and 273 deletions

View File

@@ -5,7 +5,7 @@
"url": "https://aistudio.google.com/apikey",
"api_key": "",
"metadata": {
"get_models_url": "https://generativelanguage.googleapis.com/v1beta/models",
"get_models_url": "https://generativelanguage.googleapis.com/openai/v1beta/models",
"header_template": "Authorization: Bearer {{api_key}}",
"transform_req": {
"chat_completions": {

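For context, the metadata block above is presumably what the remote-provider loader consumes when listing models: get_models_url is fetched with headers rendered from header_template, where {{api_key}} is replaced by the stored key. A minimal sketch under those assumptions; the ProviderConfig type and listRemoteModels helper are illustrative names, not part of the codebase.

// Illustrative sketch: list a remote provider's models from its metadata.
interface ProviderConfig {
  api_key: string
  metadata: {
    get_models_url: string
    header_template: string // e.g. "Authorization: Bearer {{api_key}}"
  }
}

async function listRemoteModels(config: ProviderConfig): Promise<unknown> {
  // Substitute the {{api_key}} placeholder into the header template.
  const rendered = config.metadata.header_template.replace(
    '{{api_key}}',
    config.api_key
  )
  // "Authorization: Bearer sk-..." -> { Authorization: 'Bearer sk-...' }
  const [name, ...rest] = rendered.split(':')
  const headers = { [name.trim()]: rest.join(':').trim() }

  const res = await fetch(config.metadata.get_models_url, { headers })
  if (!res.ok) throw new Error(`Failed to list models: ${res.status}`)
  return res.json()
}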
View File

@@ -107,7 +107,6 @@ mod tests {
use super::*;
use std::fs::{self, File};
use std::io::Write;
use serde_json::to_string;
use tauri::test::mock_app;
#[test]

View File

@@ -1,3 +1,4 @@
import 'openai/shims/web'
import { useCallback, useMemo, useState } from 'react'
import {
@@ -18,11 +19,66 @@ import { useAtom, useAtomValue } from 'jotai'
import { atomWithStorage } from 'jotai/utils'
import useSWR from 'swr'
import { models, TokenJS } from 'token.js'
import { LLMProvider } from 'token.js/dist/chat'
import { getDescriptionByEngine, getTitleByEngine } from '@/utils/modelEngine'
import { extensionManager } from '@/extension/ExtensionManager'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
export const builtInEngines = [
'openai',
'ai21',
'anthropic',
'gemini',
'cohere',
'bedrock',
'mistral',
'groq',
'perplexity',
'openrouter',
'openai-compatible',
]
export const convertBuiltInEngine = (engine?: string): LLMProvider => {
const engineName = normalizeBuiltInEngineName(engine) ?? ''
return (
builtInEngines.includes(engineName) ? engineName : 'openai-compatible'
) as LLMProvider
}
export const normalizeBuiltInEngineName = (
engine?: string
): string | undefined => {
return engine === ('google_gemini' as InferenceEngine) ? 'gemini' : engine
}
export const extendBuiltInEngineModels = (
tokenJS: TokenJS,
provider: LLMProvider,
model?: string
) => {
if (provider !== 'openrouter' && provider !== 'openai-compatible' && model) {
if (
Object.keys(models).includes(provider) &&
(models[provider].models as unknown as string[]).includes(model)
) {
return
}
try {
// @ts-expect-error Unknown extendModelList provider type
tokenJS.extendModelList(provider, model, {
streaming: true,
toolCalls: true,
})
} catch (error) {
console.error('Failed to extend model list:', error)
}
}
}
export const releasedEnginesCacheAtom = atomWithStorage<{
data: EngineReleased[]
timestamp: number
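Together these helpers map a stored engine name onto a token.js provider and register models that token.js does not ship with. A short usage sketch; the engine name, model id, API key, and import path below are illustrative assumptions.

// Illustrative usage of the helpers above; not part of the diff.
import { TokenJS } from 'token.js'
import {
  convertBuiltInEngine,
  extendBuiltInEngineModels,
} from '@/hooks/useEngineManagement'

// 'google_gemini' normalizes to 'gemini', a built-in token.js provider;
// unrecognized engines fall back to the generic provider.
const provider = convertBuiltInEngine('google_gemini') // -> 'gemini'
const custom = convertBuiltInEngine('my-custom-engine') // -> 'openai-compatible'
console.log(provider, custom)

const tokenJS = new TokenJS({ apiKey: '<provider-api-key>' })

// Register a model id token.js may not know about, marking it as supporting
// streaming and tool calls; 'openrouter' and 'openai-compatible' are skipped.
extendBuiltInEngineModels(tokenJS, provider, 'gemini-2.0-flash')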

View File

@@ -48,7 +48,7 @@ export default function useFactoryReset() {
// 2: Delete the old jan data folder
setFactoryResetState(FactoryResetState.DeletingData)
await fs.rm({ args: [janDataFolderPath] })
await fs.rm(janDataFolderPath)
// 3: Set the default jan data folder
if (!keepCurrentFolder) {
@@ -61,6 +61,8 @@ export default function useFactoryReset() {
await window.core?.api?.updateAppConfiguration({ configuration })
}
await window.core?.api?.installExtensions()
// Perform factory reset
// await window.core?.api?.factoryReset()

View File

@@ -1,3 +1,5 @@
import 'openai/shims/web'
import { useEffect, useRef } from 'react'
import {
@@ -18,16 +20,19 @@ import {
} from '@janhq/core'
import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { OpenAI } from 'openai'
import {
ChatCompletionMessageParam,
ChatCompletionRole as OpenAIChatCompletionRole,
ChatCompletionTool,
ChatCompletionMessageToolCall,
} from 'openai/resources/chat'
import { Stream } from 'openai/streaming'
import {
CompletionResponse,
StreamCompletionResponse,
TokenJS,
models,
} from 'token.js'
import { ulid } from 'ulidx'
import { modelDropdownStateAtom } from '@/containers/ModelDropdown'
@@ -38,12 +43,23 @@ import {
} from '@/containers/Providers/Jotai'
import { compressImage, getBase64 } from '@/utils/base64'
import {
createMessage,
createMessageContent,
emptyMessageContent,
} from '@/utils/createMessage'
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
import { useActiveModel } from './useActiveModel'
import {
convertBuiltInEngine,
extendBuiltInEngineModels,
useGetEngines,
} from './useEngineManagement'
import { extensionManager } from '@/extension/ExtensionManager'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import {
@@ -100,6 +116,8 @@ export default function useSendChatMessage(
const selectedModelRef = useRef<Model | undefined>()
const { engines } = useGetEngines()
useEffect(() => {
modelRef.current = activeModel
}, [activeModel])
@@ -167,174 +185,206 @@ export default function useSendChatMessage(
setCurrentPrompt('')
setEditPrompt('')
let base64Blob = fileUpload ? await getBase64(fileUpload.file) : undefined
try {
let base64Blob = fileUpload ? await getBase64(fileUpload.file) : undefined
if (base64Blob && fileUpload?.type === 'image') {
// Compress image
base64Blob = await compressImage(base64Blob, 512)
}
const modelRequest = selectedModel ?? activeAssistant.model
// Fallback support for previous broken threads
if (activeAssistant.model?.id === '*') {
activeAssistant.model = {
id: currentModel.id,
settings: currentModel.settings,
parameters: currentModel.parameters,
if (base64Blob && fileUpload?.type === 'image') {
// Compress image
base64Blob = await compressImage(base64Blob, 512)
}
}
if (runtimeParams.stream == null) {
runtimeParams.stream = true
}
// Build Message Request
const requestBuilder = new MessageRequestBuilder(
MessageRequestType.Thread,
{
...modelRequest,
settings: settingParams,
parameters: runtimeParams,
},
activeThread,
messages ?? currentMessages,
(await window.core.api.getTools())
?.filter((tool: ModelTool) => !disabledTools.includes(tool.name))
.map((tool: ModelTool) => ({
type: 'function' as const,
function: {
name: tool.name,
description: tool.description?.slice(0, 1024),
parameters: tool.inputSchema,
strict: false,
},
}))
).addSystemMessage(activeAssistant.instructions)
const modelRequest = selectedModel ?? activeAssistant.model
requestBuilder.pushMessage(prompt, base64Blob, fileUpload)
// Build Thread Message to persist
const threadMessageBuilder = new ThreadMessageBuilder(
requestBuilder
).pushMessage(prompt, base64Blob, fileUpload)
const newMessage = threadMessageBuilder.build()
// Update thread state
const updatedThread: Thread = {
...activeThread,
updated: newMessage.created_at,
metadata: {
...activeThread.metadata,
lastMessage: prompt,
},
}
updateThread(updatedThread)
if (
!isResend &&
(newMessage.content.length || newMessage.attachments?.length)
) {
// Add message
const createdMessage = await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.createMessage(newMessage)
.catch(() => undefined)
if (!createdMessage) return
// Push to states
addNewMessage(createdMessage)
}
// Start Model if not started
const modelId = selectedModel?.id ?? activeAssistantRef.current?.model.id
if (base64Blob) {
setFileUpload(undefined)
}
if (modelRef.current?.id !== modelId && modelId) {
const error = await startModel(modelId).catch((error: Error) => error)
if (error) {
updateThreadWaiting(activeThread.id, false)
return
// Fallback support for previous broken threads
if (activeAssistant.model?.id === '*') {
activeAssistant.model = {
id: currentModel.id,
settings: currentModel.settings,
parameters: currentModel.parameters,
}
}
if (runtimeParams.stream == null) {
runtimeParams.stream = true
}
}
setIsGeneratingResponse(true)
if (requestBuilder.tools && requestBuilder.tools.length) {
let isDone = false
const openai = new OpenAI({
apiKey: await window.core.api.appToken(),
baseURL: `${API_BASE_URL}/v1`,
dangerouslyAllowBrowser: true,
})
let parentMessageId: string | undefined
while (!isDone) {
let messageId = ulid()
if (!parentMessageId) {
parentMessageId = ulid()
messageId = parentMessageId
}
const data = requestBuilder.build()
const message: ThreadMessage = {
id: messageId,
object: 'message',
thread_id: activeThread.id,
assistant_id: activeAssistant.assistant_id,
role: ChatCompletionRole.Assistant,
content: [],
metadata: {
...(messageId !== parentMessageId
? { parent_id: parentMessageId }
: {}),
},
status: MessageStatus.Pending,
created_at: Date.now() / 1000,
completed_at: Date.now() / 1000,
}
events.emit(MessageEvent.OnMessageResponse, message)
const response = await openai.chat.completions.create({
messages: requestBuilder.messages as ChatCompletionMessageParam[],
model: data.model?.id ?? '',
tools: data.tools as ChatCompletionTool[],
stream: data.model?.parameters?.stream ?? false,
tool_choice: 'auto',
})
// Variables to track and accumulate streaming content
if (!message.content.length) {
message.content = [
{
type: ContentType.Text,
text: {
value: '',
annotations: [],
},
// Build Message Request
const requestBuilder = new MessageRequestBuilder(
MessageRequestType.Thread,
{
...modelRequest,
settings: settingParams,
parameters: runtimeParams,
},
activeThread,
messages ?? currentMessages,
(await window.core.api.getTools())
?.filter((tool: ModelTool) => !disabledTools.includes(tool.name))
.map((tool: ModelTool) => ({
type: 'function' as const,
function: {
name: tool.name,
description: tool.description?.slice(0, 1024),
parameters: tool.inputSchema,
strict: false,
},
]
}
if (data.model?.parameters?.stream)
isDone = await processStreamingResponse(
response as Stream<OpenAI.Chat.Completions.ChatCompletionChunk>,
requestBuilder,
message
)
else {
isDone = await processNonStreamingResponse(
response as OpenAI.Chat.Completions.ChatCompletion,
requestBuilder,
message
)
}
message.status = MessageStatus.Ready
events.emit(MessageEvent.OnMessageUpdate, message)
}))
).addSystemMessage(activeAssistant.instructions)
requestBuilder.pushMessage(prompt, base64Blob, fileUpload)
// Build Thread Message to persist
const threadMessageBuilder = new ThreadMessageBuilder(
requestBuilder
).pushMessage(prompt, base64Blob, fileUpload)
const newMessage = threadMessageBuilder.build()
// Update thread state
const updatedThread: Thread = {
...activeThread,
updated: newMessage.created_at,
metadata: {
...activeThread.metadata,
lastMessage: prompt,
},
}
} else {
// Request for inference
EngineManager.instance()
.get(InferenceEngine.cortex)
?.inference(requestBuilder.build())
updateThread(updatedThread)
if (
!isResend &&
(newMessage.content.length || newMessage.attachments?.length)
) {
// Add message
const createdMessage = await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.createMessage(newMessage)
.catch(() => undefined)
if (!createdMessage) return
// Push to states
addNewMessage(createdMessage)
}
// Start Model if not started
const modelId = selectedModel?.id ?? activeAssistantRef.current?.model.id
if (base64Blob) {
setFileUpload(undefined)
}
if (modelRef.current?.id !== modelId && modelId) {
const error = await startModel(modelId).catch((error: Error) => error)
if (error) {
updateThreadWaiting(activeThread.id, false)
return
}
}
setIsGeneratingResponse(true)
if (requestBuilder.tools && requestBuilder.tools.length) {
let isDone = false
const engine =
engines?.[requestBuilder.model.engine as InferenceEngine]?.[0]
const apiKey = engine?.api_key
const provider = convertBuiltInEngine(engine?.engine)
const tokenJS = new TokenJS({
apiKey: apiKey ?? (await window.core.api.appToken()),
baseURL: apiKey ? undefined : `${API_BASE_URL}/v1`,
})
extendBuiltInEngineModels(tokenJS, provider, modelId)
let parentMessageId: string | undefined
while (!isDone) {
let messageId = ulid()
if (!parentMessageId) {
parentMessageId = ulid()
messageId = parentMessageId
}
const data = requestBuilder.build()
const message: ThreadMessage = createMessage({
id: messageId,
thread_id: activeThread.id,
assistant_id: activeAssistant.assistant_id,
metadata: {
...(messageId !== parentMessageId
? { parent_id: parentMessageId }
: {}),
},
})
events.emit(MessageEvent.OnMessageResponse, message)
// Variables to track and accumulate streaming content
if (
data.model?.parameters?.stream &&
data.model?.engine !== InferenceEngine.cortex &&
data.model?.engine !== InferenceEngine.cortex_llamacpp
) {
const response = await tokenJS.chat.completions.create({
stream: true,
provider,
messages: requestBuilder.messages as ChatCompletionMessageParam[],
model: data.model?.id ?? '',
tools: data.tools as ChatCompletionTool[],
tool_choice: 'auto',
})
if (!message.content.length) {
message.content = emptyMessageContent
}
isDone = await processStreamingResponse(
response,
requestBuilder,
message
)
} else {
const response = await tokenJS.chat.completions.create({
stream: false,
provider,
messages: requestBuilder.messages as ChatCompletionMessageParam[],
model: data.model?.id ?? '',
tools: data.tools as ChatCompletionTool[],
tool_choice: 'auto',
})
// Variables to track and accumulate streaming content
if (!message.content.length) {
message.content = emptyMessageContent
}
isDone = await processNonStreamingResponse(
response,
requestBuilder,
message
)
}
message.status = MessageStatus.Ready
events.emit(MessageEvent.OnMessageUpdate, message)
}
} else {
// Request for inference
EngineManager.instance()
.get(InferenceEngine.cortex)
?.inference(requestBuilder.build())
}
} catch (error) {
setIsGeneratingResponse(false)
updateThreadWaiting(activeThread.id, false)
const errorMessage: ThreadMessage = createMessage({
thread_id: activeThread.id,
assistant_id: activeAssistant.assistant_id,
content: createMessageContent(
typeof error === 'object' && error && 'message' in error
? (error as { message: string }).message
: JSON.stringify(error)
),
})
events.emit(MessageEvent.OnMessageResponse, errorMessage)
errorMessage.status = MessageStatus.Error
events.emit(MessageEvent.OnMessageUpdate, errorMessage)
}
// Reset states
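Condensed, the new tool-use path builds a token.js client from the active engine, falling back to the local app token and the ${API_BASE_URL}/v1 proxy when no remote key is configured, and loops until the model answers without requesting tools. A simplified sketch of that loop; runToolLoop, its options, and the stub tool results are illustrative, while the helper and type names follow the diff.

import { TokenJS } from 'token.js'
import {
  ChatCompletionMessageParam,
  ChatCompletionTool,
} from 'openai/resources/chat'
import {
  convertBuiltInEngine,
  extendBuiltInEngineModels,
} from './useEngineManagement'

// Simplified sketch of the tool-use loop; not the hook's actual code.
async function runToolLoop(opts: {
  engineName?: string
  apiKey?: string // remote provider key, if one is configured
  appToken: string // local app token used when no remote key is set
  baseURL: string // local proxy, e.g. `${API_BASE_URL}/v1`
  modelId: string
  messages: ChatCompletionMessageParam[]
  tools: ChatCompletionTool[]
}) {
  const provider = convertBuiltInEngine(opts.engineName)
  const tokenJS = new TokenJS({
    apiKey: opts.apiKey ?? opts.appToken,
    baseURL: opts.apiKey ? undefined : opts.baseURL,
  })
  // Ensure token.js treats the model as supporting streaming and tool calls.
  extendBuiltInEngineModels(tokenJS, provider, opts.modelId)

  let isDone = false
  while (!isDone) {
    // Non-streaming variant only; the hook also covers stream: true via
    // processStreamingResponse.
    const response = await tokenJS.chat.completions.create({
      provider,
      model: opts.modelId,
      messages: opts.messages,
      tools: opts.tools,
      tool_choice: 'auto',
      stream: false,
    })

    const choice = response.choices[0]
    const toolCalls = choice?.message?.tool_calls ?? []
    if (!choice || toolCalls.length === 0) {
      // Final answer: nothing left for the model to run.
      isDone = true
      continue
    }

    // In the hook, postMessageProcessing executes each requested tool and
    // appends the results before the next turn; this sketch has no tool
    // runtime, so it appends placeholder results instead.
    opts.messages.push(choice.message as ChatCompletionMessageParam)
    for (const call of toolCalls) {
      opts.messages.push({
        role: 'tool',
        tool_call_id: call.id,
        content: 'tool result placeholder',
      })
    }
  }
}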
@@ -343,7 +393,7 @@ export default function useSendChatMessage(
}
const processNonStreamingResponse = async (
response: OpenAI.Chat.Completions.ChatCompletion,
response: CompletionResponse,
requestBuilder: MessageRequestBuilder,
message: ThreadMessage
): Promise<boolean> => {
@@ -351,15 +401,7 @@ export default function useSendChatMessage(
const toolCalls: ChatCompletionMessageToolCall[] =
response.choices[0]?.message?.tool_calls ?? []
const content = response.choices[0].message?.content
message.content = [
{
type: ContentType.Text,
text: {
value: content ?? '',
annotations: [],
},
},
]
message.content = createMessageContent(content ?? '')
events.emit(MessageEvent.OnMessageUpdate, message)
await postMessageProcessing(
toolCalls ?? [],
@@ -371,7 +413,7 @@ export default function useSendChatMessage(
}
const processStreamingResponse = async (
response: Stream<OpenAI.Chat.Completions.ChatCompletionChunk>,
response: StreamCompletionResponse,
requestBuilder: MessageRequestBuilder,
message: ThreadMessage
): Promise<boolean> => {
@@ -428,15 +470,7 @@ export default function useSendChatMessage(
const content = chunk.choices[0].delta.content
accumulatedContent += content
message.content = [
{
type: ContentType.Text,
text: {
value: accumulatedContent,
annotations: [],
},
},
]
message.content = createMessageContent(accumulatedContent)
events.emit(MessageEvent.OnMessageUpdate, message)
}
}

View File

@@ -65,6 +65,7 @@
"swr": "^2.2.5",
"tailwind-merge": "^2.0.0",
"tailwindcss": "3.4.17",
"token.js": "npm:token.js-fork@0.7.2",
"ulidx": "^2.3.0",
"use-debounce": "^10.0.0",
"uuid": "^9.0.1",

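The token.js entry uses an npm alias, so the dependency installs token.js-fork@0.7.2 while source files keep importing from 'token.js'. A small illustrative check; the API key is a placeholder.

// These imports resolve to the forked package because of the alias above.
import { TokenJS, models } from 'token.js'

const tokenJS = new TokenJS({ apiKey: '<api-key>' }) // illustrative key
console.log(Object.keys(models).length, 'providers bundled with the fork')
console.log(typeof tokenJS.chat.completions.create) // 'function'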
View File

@@ -32,6 +32,8 @@ import { twMerge } from 'tailwind-merge'
import Spinner from '@/containers/Loader/Spinner'
import {
builtInEngines,
normalizeBuiltInEngineName,
updateEngine,
useGetEngines,
useRefreshModelList,
@@ -366,105 +368,111 @@ const RemoteEngineSettings = ({
</div>
</div>
</div>
<div className="block w-full px-4">
<div className="mb-3 mt-4 border-b border-[hsla(var(--app-border))] pb-4">
<div className="flex w-full flex-col items-start justify-between sm:flex-row">
<div className="w-full flex-shrink-0 space-y-1.5">
<div className="flex items-start justify-between gap-x-2">
<div className="w-full sm:w-3/4">
<h6 className="line-clamp-1 font-semibold">
Request Headers Template
</h6>
<p className="mt-1 text-[hsla(var(--text-secondary))]">
HTTP headers template required for API authentication
and version specification.
</p>
</div>
<div className="w-full">
<TextArea
placeholder="Enter headers template"
value={data?.metadata?.header_template}
onChange={(e) =>
handleChange(
'metadata.header_template',
e.target.value
)
}
/>
{!builtInEngines.includes(
normalizeBuiltInEngineName(engineName) ?? ''
) && (
<>
<div className="block w-full px-4">
<div className="mb-3 mt-4 border-b border-[hsla(var(--app-border))] pb-4">
<div className="flex w-full flex-col items-start justify-between sm:flex-row">
<div className="w-full flex-shrink-0 space-y-1.5">
<div className="flex items-start justify-between gap-x-2">
<div className="w-full sm:w-3/4">
<h6 className="line-clamp-1 font-semibold">
Request Headers Template
</h6>
<p className="mt-1 text-[hsla(var(--text-secondary))]">
HTTP headers template required for API
authentication and version specification.
</p>
</div>
<div className="w-full">
<TextArea
placeholder="Enter headers template"
value={data?.metadata?.header_template}
onChange={(e) =>
handleChange(
'metadata.header_template',
e.target.value
)
}
/>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
<div className="block w-full px-4">
<div className="mb-3 mt-4 border-b border-[hsla(var(--app-border))] pb-4">
<div className="flex w-full flex-col items-start justify-between sm:flex-row">
<div className="w-full flex-shrink-0 space-y-1.5">
<div className="flex items-start justify-between gap-x-2">
<div className="w-full sm:w-3/4">
<h6 className="line-clamp-1 font-semibold">
Request Format Conversion
</h6>
<p className="mt-1 text-[hsla(var(--text-secondary))]">
Template to transform OpenAI-compatible requests into
provider-specific format.
</p>
</div>
<div className="w-full">
<TextArea
placeholder="Enter conversion function"
value={
data?.metadata?.transform_req?.chat_completions
?.template
}
onChange={(e) =>
handleChange(
'metadata.transform_req.chat_completions.template',
e.target.value
)
}
/>
<div className="block w-full px-4">
<div className="mb-3 mt-4 border-b border-[hsla(var(--app-border))] pb-4">
<div className="flex w-full flex-col items-start justify-between sm:flex-row">
<div className="w-full flex-shrink-0 space-y-1.5">
<div className="flex items-start justify-between gap-x-2">
<div className="w-full sm:w-3/4">
<h6 className="line-clamp-1 font-semibold">
Request Format Conversion
</h6>
<p className="mt-1 text-[hsla(var(--text-secondary))]">
Template to transform OpenAI-compatible requests
into provider-specific format.
</p>
</div>
<div className="w-full">
<TextArea
placeholder="Enter conversion function"
value={
data?.metadata?.transform_req?.chat_completions
?.template
}
onChange={(e) =>
handleChange(
'metadata.transform_req.chat_completions.template',
e.target.value
)
}
/>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
<div className="block w-full px-4">
<div className="mb-3 mt-4 pb-4">
<div className="flex w-full flex-col items-start justify-between sm:flex-row">
<div className="w-full flex-shrink-0 space-y-1.5">
<div className="flex items-start justify-between gap-x-2">
<div className="w-full sm:w-3/4">
<h6 className="line-clamp-1 font-semibold">
Response Format Conversion
</h6>
<p className="mt-1 text-[hsla(var(--text-secondary))]">
Template to transform provider responses into
OpenAI-compatible format.
</p>
</div>
<div className="w-full">
<TextArea
placeholder="Enter conversion function"
value={
data?.metadata?.transform_resp?.chat_completions
?.template
}
onChange={(e) =>
handleChange(
'metadata.transform_resp.chat_completions.template',
e.target.value
)
}
/>
<div className="block w-full px-4">
<div className="mb-3 mt-4 pb-4">
<div className="flex w-full flex-col items-start justify-between sm:flex-row">
<div className="w-full flex-shrink-0 space-y-1.5">
<div className="flex items-start justify-between gap-x-2">
<div className="w-full sm:w-3/4">
<h6 className="line-clamp-1 font-semibold">
Response Format Conversion
</h6>
<p className="mt-1 text-[hsla(var(--text-secondary))]">
Template to transform provider responses into
OpenAI-compatible format.
</p>
</div>
<div className="w-full">
<TextArea
placeholder="Enter conversion function"
value={
data?.metadata?.transform_resp?.chat_completions
?.template
}
onChange={(e) =>
handleChange(
'metadata.transform_resp.chat_completions.template',
e.target.value
)
}
/>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</>
)}
</div>
)}
</ScrollArea>
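The new wrapper condition hides the header template and the request/response transform templates for built-in providers, leaving them editable only for custom OpenAI-compatible engines. A small sketch of the gate; the engine name and import path are assumptions.

import {
  builtInEngines,
  normalizeBuiltInEngineName,
} from '@/hooks/useEngineManagement'

const engineName = 'google_gemini' // illustrative stored engine name
const showAdvancedTemplates = !builtInEngines.includes(
  normalizeBuiltInEngineName(engineName) ?? ''
)
console.log(showAdvancedTemplates) // false: built-in providers hide the templates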

View File

@@ -22,6 +22,8 @@ export const Routes = [
'saveMcpConfigs',
'getMcpConfigs',
'restartMcpServers',
'relaunch',
'installExtensions',
].map((r) => ({
path: `app`,
route: r,

View File

@@ -0,0 +1,45 @@
import {
ChatCompletionRole,
ContentType,
MessageStatus,
ThreadContent,
ThreadMessage,
} from '@janhq/core'
import { ulid } from 'ulidx'
export const emptyMessageContent: ThreadContent[] = [
{
type: ContentType.Text,
text: {
value: '',
annotations: [],
},
},
]
export const createMessageContent = (text: string): ThreadContent[] => {
return [
{
type: ContentType.Text,
text: {
value: text,
annotations: [],
},
},
]
}
export const createMessage = (opts: Partial<ThreadMessage>): ThreadMessage => {
return {
id: opts.id ?? ulid(),
object: 'message',
thread_id: opts.thread_id ?? '',
assistant_id: opts.assistant_id ?? '',
role: opts.role ?? ChatCompletionRole.Assistant,
content: opts.content ?? [],
metadata: opts.metadata ?? {},
status: opts.status ?? MessageStatus.Pending,
created_at: opts.created_at ?? Date.now() / 1000,
completed_at: opts.completed_at ?? Date.now() / 1000,
}
}
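A brief usage sketch of these helpers, assuming only the fields that differ from the defaults need to be supplied; the ids and error text below are illustrative.

import { MessageStatus } from '@janhq/core'
import { createMessage, createMessageContent } from '@/utils/createMessage'

// Assistant placeholder: id, role, status, and timestamps fall back to the
// defaults (ulid id, Assistant role, Pending status, current time).
const pending = createMessage({
  thread_id: 'thread_123',
  assistant_id: 'assistant_abc',
})

// Error message built from a plain string, then marked as failed.
const failed = createMessage({
  thread_id: 'thread_123',
  assistant_id: 'assistant_abc',
  content: createMessageContent('Request failed: invalid API key'),
  status: MessageStatus.Error,
})
console.log(pending.status, failed.status) // Pending, Error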