diff --git a/extensions/engine-management-extension/resources/google_gemini.json b/extensions/engine-management-extension/resources/google_gemini.json
index e0fa809a5..f860a1990 100644
--- a/extensions/engine-management-extension/resources/google_gemini.json
+++ b/extensions/engine-management-extension/resources/google_gemini.json
@@ -5,7 +5,7 @@
   "url": "https://aistudio.google.com/apikey",
   "api_key": "",
   "metadata": {
-    "get_models_url": "https://generativelanguage.googleapis.com/v1beta/models",
+    "get_models_url": "https://generativelanguage.googleapis.com/openai/v1beta/models",
     "header_template": "Authorization: Bearer {{api_key}}",
     "transform_req": {
       "chat_completions": {
diff --git a/src-tauri/src/core/fs.rs b/src-tauri/src/core/fs.rs
index c0d7d423d..66486cf0a 100644
--- a/src-tauri/src/core/fs.rs
+++ b/src-tauri/src/core/fs.rs
@@ -107,7 +107,6 @@ mod tests {
     use super::*;
     use std::fs::{self, File};
     use std::io::Write;
-    use serde_json::to_string;
     use tauri::test::mock_app;

     #[test]
diff --git a/web/hooks/useEngineManagement.ts b/web/hooks/useEngineManagement.ts
index d9eacb592..8c19737ac 100644
--- a/web/hooks/useEngineManagement.ts
+++ b/web/hooks/useEngineManagement.ts
@@ -1,3 +1,4 @@
+import 'openai/shims/web'
 import { useCallback, useMemo, useState } from 'react'

 import {
@@ -18,11 +19,66 @@ import { useAtom, useAtomValue } from 'jotai'
 import { atomWithStorage } from 'jotai/utils'
 import useSWR from 'swr'

+import { models, TokenJS } from 'token.js'
+import { LLMProvider } from 'token.js/dist/chat'
+
 import { getDescriptionByEngine, getTitleByEngine } from '@/utils/modelEngine'

 import { extensionManager } from '@/extension/ExtensionManager'

 import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'

+export const builtInEngines = [
+  'openai',
+  'ai21',
+  'anthropic',
+  'gemini',
+  'cohere',
+  'bedrock',
+  'mistral',
+  'groq',
+  'perplexity',
+  'openrouter',
+  'openai-compatible',
+]
+
+export const convertBuiltInEngine = (engine?: string): LLMProvider => {
+  const engineName = normalizeBuiltInEngineName(engine) ?? ''
+  return (
+    builtInEngines.includes(engineName) ? engineName : 'openai-compatible'
+  ) as LLMProvider
+}
+
+export const normalizeBuiltInEngineName = (
+  engine?: string
+): string | undefined => {
+  return engine === ('google_gemini' as InferenceEngine) ? 'gemini' : engine
+}
+
+export const extendBuiltInEngineModels = (
+  tokenJS: TokenJS,
+  provider: LLMProvider,
+  model?: string
+) => {
+  if (provider !== 'openrouter' && provider !== 'openai-compatible' && model) {
+    if (
+      provider in Object.keys(models) &&
+      (models[provider].models as unknown as string[]).includes(model)
+    ) {
+      return
+    }
+
+    try {
+      // @ts-expect-error Unknown extendModelList provider type
+      tokenJS.extendModelList(provider, model, {
+        streaming: true,
+        toolCalls: true,
+      })
+    } catch (error) {
+      console.error('Failed to extend model list:', error)
+    }
+  }
+}
+
 export const releasedEnginesCacheAtom = atomWithStorage<{
   data: EngineReleased[]
   timestamp: number
diff --git a/web/hooks/useFactoryReset.ts b/web/hooks/useFactoryReset.ts
index 90723f9cd..c582cf685 100644
--- a/web/hooks/useFactoryReset.ts
+++ b/web/hooks/useFactoryReset.ts
@@ -48,7 +48,7 @@ export default function useFactoryReset() {

     // 2: Delete the old jan data folder
     setFactoryResetState(FactoryResetState.DeletingData)
-    await fs.rm({ args: [janDataFolderPath] })
+    await fs.rm(janDataFolderPath)

     // 3: Set the default jan data folder
     if (!keepCurrentFolder) {
@@ -61,6 +61,8 @@
       await window.core?.api?.updateAppConfiguration({ configuration })
     }

+    await window.core?.api?.installExtensions()
+
     // Perform factory reset
     // await window.core?.api?.factoryReset()
diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts
index cd6ff2692..4be7ac035 100644
--- a/web/hooks/useSendChatMessage.ts
+++ b/web/hooks/useSendChatMessage.ts
@@ -1,3 +1,5 @@
+import 'openai/shims/web'
+
 import { useEffect, useRef } from 'react'

 import {
@@ -18,16 +20,19 @@ import {
 } from '@janhq/core'
 import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
 import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
-import { OpenAI } from 'openai'

 import {
   ChatCompletionMessageParam,
-  ChatCompletionRole as OpenAIChatCompletionRole,
   ChatCompletionTool,
   ChatCompletionMessageToolCall,
 } from 'openai/resources/chat'
-import { Stream } from 'openai/streaming'
+import {
+  CompletionResponse,
+  StreamCompletionResponse,
+  TokenJS,
+  models,
+} from 'token.js'

 import { ulid } from 'ulidx'

 import { modelDropdownStateAtom } from '@/containers/ModelDropdown'
@@ -38,12 +43,23 @@ import {
 } from '@/containers/Providers/Jotai'

 import { compressImage, getBase64 } from '@/utils/base64'
+import {
+  createMessage,
+  createMessageContent,
+  emptyMessageContent,
+} from '@/utils/createMessage'
 import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'

 import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'

 import { useActiveModel } from './useActiveModel'

+import {
+  convertBuiltInEngine,
+  extendBuiltInEngineModels,
+  useGetEngines,
+} from './useEngineManagement'
+
 import { extensionManager } from '@/extension/ExtensionManager'
 import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
 import {
@@ -100,6 +116,8 @@
   const selectedModelRef = useRef()

+  const { engines } = useGetEngines()
+
   useEffect(() => {
     modelRef.current = activeModel
   }, [activeModel])

@@ -167,174 +185,206 @@
     setCurrentPrompt('')
     setEditPrompt('')

-    let base64Blob = fileUpload ? await getBase64(fileUpload.file) : undefined
+    try {
+      let base64Blob = fileUpload ? await getBase64(fileUpload.file) : undefined

-    if (base64Blob && fileUpload?.type === 'image') {
-      // Compress image
-      base64Blob = await compressImage(base64Blob, 512)
-    }
-
-    const modelRequest = selectedModel ?? activeAssistant.model
-
-    // Fallback support for previous broken threads
-    if (activeAssistant.model?.id === '*') {
-      activeAssistant.model = {
-        id: currentModel.id,
-        settings: currentModel.settings,
-        parameters: currentModel.parameters,
+      if (base64Blob && fileUpload?.type === 'image') {
+        // Compress image
+        base64Blob = await compressImage(base64Blob, 512)
       }
-    }
-    if (runtimeParams.stream == null) {
-      runtimeParams.stream = true
-    }
-    // Build Message Request
-    const requestBuilder = new MessageRequestBuilder(
-      MessageRequestType.Thread,
-      {
-        ...modelRequest,
-        settings: settingParams,
-        parameters: runtimeParams,
-      },
-      activeThread,
-      messages ?? currentMessages,
-      (await window.core.api.getTools())
-        ?.filter((tool: ModelTool) => !disabledTools.includes(tool.name))
-        .map((tool: ModelTool) => ({
-          type: 'function' as const,
-          function: {
-            name: tool.name,
-            description: tool.description?.slice(0, 1024),
-            parameters: tool.inputSchema,
-            strict: false,
-          },
-        }))
-    ).addSystemMessage(activeAssistant.instructions)
+      const modelRequest = selectedModel ?? activeAssistant.model

-    requestBuilder.pushMessage(prompt, base64Blob, fileUpload)
-
-    // Build Thread Message to persist
-    const threadMessageBuilder = new ThreadMessageBuilder(
-      requestBuilder
-    ).pushMessage(prompt, base64Blob, fileUpload)
-
-    const newMessage = threadMessageBuilder.build()
-
-    // Update thread state
-    const updatedThread: Thread = {
-      ...activeThread,
-      updated: newMessage.created_at,
-      metadata: {
-        ...activeThread.metadata,
-        lastMessage: prompt,
-      },
-    }
-    updateThread(updatedThread)
-
-    if (
-      !isResend &&
-      (newMessage.content.length || newMessage.attachments?.length)
-    ) {
-      // Add message
-      const createdMessage = await extensionManager
-        .get(ExtensionTypeEnum.Conversational)
-        ?.createMessage(newMessage)
-        .catch(() => undefined)
-
-      if (!createdMessage) return
-
-      // Push to states
-      addNewMessage(createdMessage)
-    }
-
-    // Start Model if not started
-    const modelId = selectedModel?.id ?? activeAssistantRef.current?.model.id
-
-    if (base64Blob) {
-      setFileUpload(undefined)
-    }
-
-    if (modelRef.current?.id !== modelId && modelId) {
-      const error = await startModel(modelId).catch((error: Error) => error)
-      if (error) {
-        updateThreadWaiting(activeThread.id, false)
-        return
+      // Fallback support for previous broken threads
+      if (activeAssistant.model?.id === '*') {
+        activeAssistant.model = {
+          id: currentModel.id,
+          settings: currentModel.settings,
+          parameters: currentModel.parameters,
+        }
+      }
+      if (runtimeParams.stream == null) {
+        runtimeParams.stream = true
       }
-    }
-    setIsGeneratingResponse(true)
-    if (requestBuilder.tools && requestBuilder.tools.length) {
-      let isDone = false
-      const openai = new OpenAI({
-        apiKey: await window.core.api.appToken(),
-        baseURL: `${API_BASE_URL}/v1`,
-        dangerouslyAllowBrowser: true,
-      })
-      let parentMessageId: string | undefined
-      while (!isDone) {
-        let messageId = ulid()
-        if (!parentMessageId) {
-          parentMessageId = ulid()
-          messageId = parentMessageId
-        }
-        const data = requestBuilder.build()
-        const message: ThreadMessage = {
-          id: messageId,
-          object: 'message',
-          thread_id: activeThread.id,
-          assistant_id: activeAssistant.assistant_id,
-          role: ChatCompletionRole.Assistant,
-          content: [],
-          metadata: {
-            ...(messageId !== parentMessageId
-              ? { parent_id: parentMessageId }
-              : {}),
-          },
-          status: MessageStatus.Pending,
-          created_at: Date.now() / 1000,
-          completed_at: Date.now() / 1000,
-        }
-        events.emit(MessageEvent.OnMessageResponse, message)
-        const response = await openai.chat.completions.create({
-          messages: requestBuilder.messages as ChatCompletionMessageParam[],
-          model: data.model?.id ?? '',
-          tools: data.tools as ChatCompletionTool[],
-          stream: data.model?.parameters?.stream ?? false,
-          tool_choice: 'auto',
-        })
-        // Variables to track and accumulate streaming content
-        if (!message.content.length) {
-          message.content = [
-            {
-              type: ContentType.Text,
-              text: {
-                value: '',
-                annotations: [],
-              },
+      // Build Message Request
+      const requestBuilder = new MessageRequestBuilder(
+        MessageRequestType.Thread,
+        {
+          ...modelRequest,
+          settings: settingParams,
+          parameters: runtimeParams,
+        },
+        activeThread,
+        messages ?? currentMessages,
+        (await window.core.api.getTools())
+          ?.filter((tool: ModelTool) => !disabledTools.includes(tool.name))
+          .map((tool: ModelTool) => ({
+            type: 'function' as const,
+            function: {
+              name: tool.name,
+              description: tool.description?.slice(0, 1024),
+              parameters: tool.inputSchema,
+              strict: false,
             },
-          ]
-        }
-        if (data.model?.parameters?.stream)
-          isDone = await processStreamingResponse(
-            response as Stream,
-            requestBuilder,
-            message
-          )
-        else {
-          isDone = await processNonStreamingResponse(
-            response as OpenAI.Chat.Completions.ChatCompletion,
-            requestBuilder,
-            message
-          )
-        }
-        message.status = MessageStatus.Ready
-        events.emit(MessageEvent.OnMessageUpdate, message)
+          }))
+      ).addSystemMessage(activeAssistant.instructions)
+
+      requestBuilder.pushMessage(prompt, base64Blob, fileUpload)
+
+      // Build Thread Message to persist
+      const threadMessageBuilder = new ThreadMessageBuilder(
+        requestBuilder
+      ).pushMessage(prompt, base64Blob, fileUpload)
+
+      const newMessage = threadMessageBuilder.build()
+
+      // Update thread state
+      const updatedThread: Thread = {
+        ...activeThread,
+        updated: newMessage.created_at,
+        metadata: {
+          ...activeThread.metadata,
+          lastMessage: prompt,
+        },
       }
-    } else {
-      // Request for inference
-      EngineManager.instance()
-        .get(InferenceEngine.cortex)
-        ?.inference(requestBuilder.build())
+      updateThread(updatedThread)
+
+      if (
+        !isResend &&
+        (newMessage.content.length || newMessage.attachments?.length)
+      ) {
+        // Add message
+        const createdMessage = await extensionManager
+          .get(ExtensionTypeEnum.Conversational)
+          ?.createMessage(newMessage)
+          .catch(() => undefined)
+
+        if (!createdMessage) return
+
+        // Push to states
+        addNewMessage(createdMessage)
+      }
+
+      // Start Model if not started
+      const modelId = selectedModel?.id ?? activeAssistantRef.current?.model.id
+
+      if (base64Blob) {
+        setFileUpload(undefined)
+      }
+
+      if (modelRef.current?.id !== modelId && modelId) {
+        const error = await startModel(modelId).catch((error: Error) => error)
+        if (error) {
+          updateThreadWaiting(activeThread.id, false)
+          return
+        }
+      }
+      setIsGeneratingResponse(true)
+
+      if (requestBuilder.tools && requestBuilder.tools.length) {
+        let isDone = false
+
+        const engine =
+          engines?.[requestBuilder.model.engine as InferenceEngine]?.[0]
+        const apiKey = engine?.api_key
+        const provider = convertBuiltInEngine(engine?.engine)
+
+        const tokenJS = new TokenJS({
+          apiKey: apiKey ?? (await window.core.api.appToken()),
+          baseURL: apiKey ? undefined : `${API_BASE_URL}/v1`,
+        })
+
+        extendBuiltInEngineModels(tokenJS, provider, modelId)
+
+        let parentMessageId: string | undefined
+        while (!isDone) {
+          let messageId = ulid()
+          if (!parentMessageId) {
+            parentMessageId = ulid()
+            messageId = parentMessageId
+          }
+          const data = requestBuilder.build()
+          const message: ThreadMessage = createMessage({
+            id: messageId,
+            thread_id: activeThread.id,
+            assistant_id: activeAssistant.assistant_id,
+            metadata: {
+              ...(messageId !== parentMessageId
+                ? { parent_id: parentMessageId }
+                : {}),
+            },
+          })
+          events.emit(MessageEvent.OnMessageResponse, message)
+          // Variables to track and accumulate streaming content
+
+          if (
+            data.model?.parameters?.stream &&
+            data.model?.engine !== InferenceEngine.cortex &&
+            data.model?.engine !== InferenceEngine.cortex_llamacpp
+          ) {
+            const response = await tokenJS.chat.completions.create({
+              stream: true,
+              provider,
+              messages: requestBuilder.messages as ChatCompletionMessageParam[],
+              model: data.model?.id ?? '',
+              tools: data.tools as ChatCompletionTool[],
+              tool_choice: 'auto',
+            })
+
+            if (!message.content.length) {
+              message.content = emptyMessageContent
+            }
+
+            isDone = await processStreamingResponse(
+              response,
+              requestBuilder,
+              message
+            )
+          } else {
+            const response = await tokenJS.chat.completions.create({
+              stream: false,
+              provider,
+              messages: requestBuilder.messages as ChatCompletionMessageParam[],
+              model: data.model?.id ?? '',
+              tools: data.tools as ChatCompletionTool[],
+              tool_choice: 'auto',
+            })
+            // Variables to track and accumulate streaming content
+            if (!message.content.length) {
+              message.content = emptyMessageContent
+            }
+            isDone = await processNonStreamingResponse(
+              response,
+              requestBuilder,
+              message
+            )
+          }
+          message.status = MessageStatus.Ready
+          events.emit(MessageEvent.OnMessageUpdate, message)
+        }
+      } else {
+        // Request for inference
+        EngineManager.instance()
+          .get(InferenceEngine.cortex)
+          ?.inference(requestBuilder.build())
+      }
+    } catch (error) {
+      setIsGeneratingResponse(false)
+      updateThreadWaiting(activeThread.id, false)
+      const errorMessage: ThreadMessage = createMessage({
+        thread_id: activeThread.id,
+        assistant_id: activeAssistant.assistant_id,
+        content: createMessageContent(
+          typeof error === 'object' && error && 'message' in error
+            ? (error as { message: string }).message
+            : JSON.stringify(error)
+        ),
+      })
+      events.emit(MessageEvent.OnMessageResponse, errorMessage)
+
+      errorMessage.status = MessageStatus.Error
+      events.emit(MessageEvent.OnMessageUpdate, errorMessage)
     }

     // Reset states
@@ -343,7 +393,7 @@ export default function useSendChatMessage(
 }

 const processNonStreamingResponse = async (
-  response: OpenAI.Chat.Completions.ChatCompletion,
+  response: CompletionResponse,
   requestBuilder: MessageRequestBuilder,
   message: ThreadMessage
 ): Promise<boolean> => {
@@ -351,15 +401,7 @@
   const toolCalls: ChatCompletionMessageToolCall[] =
     response.choices[0]?.message?.tool_calls ?? []
   const content = response.choices[0].message?.content
-  message.content = [
-    {
-      type: ContentType.Text,
-      text: {
-        value: content ?? '',
-        annotations: [],
-      },
-    },
-  ]
+  message.content = createMessageContent(content ?? '')
   events.emit(MessageEvent.OnMessageUpdate, message)
   await postMessageProcessing(
     toolCalls ?? [],
@@ -371,7 +413,7 @@
 }

 const processStreamingResponse = async (
-  response: Stream,
+  response: StreamCompletionResponse,
   requestBuilder: MessageRequestBuilder,
   message: ThreadMessage
 ): Promise<boolean> => {
@@ -428,15 +470,7 @@
       const content = chunk.choices[0].delta.content
       accumulatedContent += content
-      message.content = [
-        {
-          type: ContentType.Text,
-          text: {
-            value: accumulatedContent,
-            annotations: [],
-          },
-        },
-      ]
+      message.content = createMessageContent(accumulatedContent)
       events.emit(MessageEvent.OnMessageUpdate, message)
     }
   }
diff --git a/web/package.json b/web/package.json
index cdf2d8d8b..cf76a95f2 100644
--- a/web/package.json
+++ b/web/package.json
@@ -65,6 +65,7 @@
     "swr": "^2.2.5",
     "tailwind-merge": "^2.0.0",
     "tailwindcss": "3.4.17",
+    "token.js": "npm:token.js-fork@0.7.2",
     "ulidx": "^2.3.0",
     "use-debounce": "^10.0.0",
     "uuid": "^9.0.1",
diff --git a/web/screens/Settings/Engines/RemoteEngineSettings.tsx b/web/screens/Settings/Engines/RemoteEngineSettings.tsx
index 1ddacd432..e773b1957 100644
--- a/web/screens/Settings/Engines/RemoteEngineSettings.tsx
+++ b/web/screens/Settings/Engines/RemoteEngineSettings.tsx
@@ -32,6 +32,8 @@ import { twMerge } from 'tailwind-merge'
 import Spinner from '@/containers/Loader/Spinner'

 import {
+  builtInEngines,
+  normalizeBuiltInEngineName,
   updateEngine,
   useGetEngines,
   useRefreshModelList,
@@ -366,105 +368,111 @@ const RemoteEngineSettings = ({
-              Request Headers Template
-              HTTP headers template required for API authentication
-              and version specification.