From d5393e45637c7704f72e556944766d8c9b9f05af Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 20 May 2025 14:30:51 +0700 Subject: [PATCH] feat: add custom OpenAI provider (#5033) * feat: add custom OpenAI provider * chore: add HF token setting * chore: move HF token setting to llama.cpp provider - later deprecate model extension --- .../resources/default_settings.json | 12 +++ .../inference-cortex-extension/src/index.ts | 27 ++++++ web-app/src/containers/ProvidersMenu.tsx | 81 +++++++++++++++- .../src/containers/dialogs/DeleteProvider.tsx | 93 +++++++++++++++++++ web-app/src/hooks/useChat.ts | 1 + web-app/src/hooks/useModelProvider.ts | 14 +++ web-app/src/lib/utils.ts | 4 +- web-app/src/mock/data.ts | 26 ++++++ .../settings/providers/$providerName.tsx | 5 + web-app/src/services/providers.ts | 28 +++++- 10 files changed, 285 insertions(+), 6 deletions(-) create mode 100644 web-app/src/containers/dialogs/DeleteProvider.tsx diff --git a/extensions/inference-cortex-extension/resources/default_settings.json b/extensions/inference-cortex-extension/resources/default_settings.json index 79ca05527..5574128e5 100644 --- a/extensions/inference-cortex-extension/resources/default_settings.json +++ b/extensions/inference-cortex-extension/resources/default_settings.json @@ -93,5 +93,17 @@ "controllerProps": { "value": true } + }, + { + "key": "hugging-face-access-token", + "title": "Hugging Face Access Token", + "description": "Access tokens programmatically authenticate your identity to the Hugging Face Hub, allowing applications to perform specific actions specified by the scope of permissions granted.", + "controllerType": "input", + "controllerProps": { + "value": "", + "placeholder": "hf_**********************************", + "type": "password", + "inputActions": ["unobscure", "copy"] + } } ] diff --git a/extensions/inference-cortex-extension/src/index.ts b/extensions/inference-cortex-extension/src/index.ts index 49f4392af..01a777908 100644 --- a/extensions/inference-cortex-extension/src/index.ts +++ b/extensions/inference-cortex-extension/src/index.ts @@ -36,6 +36,7 @@ enum Settings { cache_type = 'cache_type', use_mmap = 'use_mmap', cpu_threads = 'cpu_threads', + huggingfaceToken = 'hugging-face-access-token', } type LoadedModelResponse = { data: { engine: string; id: string }[] } @@ -130,6 +131,13 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { ) if (!Number.isNaN(threads_number)) this.cpu_threads = threads_number + const huggingfaceToken = await this.getSetting( + Settings.huggingfaceToken, + '' + ) + if (huggingfaceToken) { + this.updateCortexConfig({ huggingface_token: huggingfaceToken }) + } this.subscribeToEvents() window.addEventListener('beforeunload', () => { @@ -145,6 +153,11 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { super.onUnload() } + /** + * Subscribe to settings update and make change accordingly + * @param key + * @param value + */ onSettingUpdate(key: string, value: T): void { if (key === Settings.n_parallel && typeof value === 'string') { this.n_parallel = Number(value) ?? 1 @@ -161,6 +174,8 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { } else if (key === Settings.cpu_threads && typeof value === 'string') { const threads_number = Number(value) if (!Number.isNaN(threads_number)) this.cpu_threads = threads_number + } else if (key === Settings.huggingfaceToken) { + this.updateCortexConfig({ huggingface_token: value }) } } @@ -253,6 +268,18 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { }) } + /** + * Update cortex config + * @param body + */ + private async updateCortexConfig(body: { + [key: string]: any + }): Promise { + return this.apiInstance() + .then((api) => api.patch('v1/configs', { json: body }).then(() => {})) + .catch((e) => console.debug(e)) + } + /** * Subscribe to cortex.cpp websocket events */ diff --git a/web-app/src/containers/ProvidersMenu.tsx b/web-app/src/containers/ProvidersMenu.tsx index 2b36f4f59..c1334978a 100644 --- a/web-app/src/containers/ProvidersMenu.tsx +++ b/web-app/src/containers/ProvidersMenu.tsx @@ -2,16 +2,48 @@ import { route } from '@/constants/routes' import { useModelProvider } from '@/hooks/useModelProvider' import { cn, getProviderLogo, getProviderTitle } from '@/lib/utils' import { useNavigate, useMatches, Link } from '@tanstack/react-router' -import { IconArrowLeft } from '@tabler/icons-react' +import { IconArrowLeft, IconCirclePlus } from '@tabler/icons-react' +import { + Dialog, + DialogClose, + DialogContent, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger, +} from '@/components/ui/dialog' +import { Input } from '@/components/ui/input' +import { Button } from '@/components/ui/button' +import { useCallback, useState } from 'react' +import { openAIProviderSettings } from '@/mock/data' const ProvidersMenu = ({ stepSetupRemoteProvider, }: { stepSetupRemoteProvider: boolean }) => { - const { providers } = useModelProvider() + const { providers, addProvider } = useModelProvider() const navigate = useNavigate() const matches = useMatches() + const [name, setName] = useState('') + const createProvider = useCallback(() => { + addProvider({ + provider: name, + active: true, + models: [], + settings: openAIProviderSettings as ProviderSetting[], + api_key: '', + base_url: 'https://api.openai.com/v1', + }) + setTimeout(() => { + navigate({ + to: route.settings.providers, + params: { + providerName: name, + }, + }) + }, 0) + }, [name, addProvider, navigate]) return (
@@ -65,6 +97,51 @@ const ProvidersMenu = ({
) })} + + + +
{}} + > + + + Add Provider + +
+
+ + + Add OpenAI Provider + setName(e.target.value)} + className="mt-2" + placeholder="Enter a name for your provider" + onKeyDown={(e) => { + // Prevent key from being captured by parent components + e.stopPropagation() + }} + /> + + + + + + + + + + +
) diff --git a/web-app/src/containers/dialogs/DeleteProvider.tsx b/web-app/src/containers/dialogs/DeleteProvider.tsx new file mode 100644 index 000000000..75bd1daf5 --- /dev/null +++ b/web-app/src/containers/dialogs/DeleteProvider.tsx @@ -0,0 +1,93 @@ +import { Button } from '@/components/ui/button' +import { + Dialog, + DialogClose, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger, +} from '@/components/ui/dialog' + +import { toast } from 'sonner' +import { CardItem } from '../Card' +import { models } from 'token.js' +import { EngineManager } from '@janhq/core' +import { useModelProvider } from '@/hooks/useModelProvider' +import { useRouter } from '@tanstack/react-router' +import { route } from '@/constants/routes' +import { normalizeProvider } from '@/lib/models' + +type Props = { + provider?: ProviderObject +} +const DeleteProvider = ({ provider }: Props) => { + const { deleteProvider, providers } = useModelProvider() + const router = useRouter() + if ( + !provider || + Object.keys(models).includes(provider.provider) || + EngineManager.instance().get(normalizeProvider(provider.provider)) + ) + return null + + const removeProvider = async () => { + deleteProvider(provider.provider) + toast.success('Delete Provider', { + id: `delete-provider-${provider.provider}`, + description: `Provider ${provider.provider} has been permanently deleted.`, + }) + setTimeout(() => { + router.navigate({ + to: route.settings.providers, + params: { + providerName: providers[0].provider, + }, + }) + }, 0) + } + + return ( + + + + + + + Delete Provider: {provider.provider} + + Are you sure you want to delete this provider? This action + cannot be undone. + + + + + + + + + + + + + + } + /> + ) +} +export default DeleteProvider diff --git a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts index 11d8c03e7..aa76edcb8 100644 --- a/web-app/src/hooks/useChat.ts +++ b/web-app/src/hooks/useChat.ts @@ -37,6 +37,7 @@ export const useChat = () => { const provider = useMemo(() => { return getProviderByName(selectedProvider) }, [selectedProvider, getProviderByName]) + const getCurrentThread = useCallback(async () => { let currentThread = retrieveThread() if (!currentThread) { diff --git a/web-app/src/hooks/useModelProvider.ts b/web-app/src/hooks/useModelProvider.ts index eadbb2f08..75c8bc4ba 100644 --- a/web-app/src/hooks/useModelProvider.ts +++ b/web-app/src/hooks/useModelProvider.ts @@ -13,6 +13,8 @@ type ModelProviderState = { providerName: string, modelName: string ) => Model | undefined + addProvider: (provider: ModelProvider) => void + deleteProvider: (providerName: string) => void deleteModel: (modelId: string) => void } @@ -114,6 +116,18 @@ export const useModelProvider = create()( }), })) }, + addProvider: (provider: ModelProvider) => { + set((state) => ({ + providers: [...state.providers, provider], + })) + }, + deleteProvider: (providerName: string) => { + set((state) => ({ + providers: state.providers.filter( + (provider) => provider.provider !== providerName + ), + })) + }, }), { name: localStorageKey.modelProvider, diff --git a/web-app/src/lib/utils.ts b/web-app/src/lib/utils.ts index 9d15c39cb..e43f98649 100644 --- a/web-app/src/lib/utils.ts +++ b/web-app/src/lib/utils.ts @@ -17,8 +17,6 @@ export function getProviderLogo(provider: string) { return '/images/model-provider/martian.svg' case 'openrouter': return '/images/model-provider/openRouter.svg' - case 'openai': - return '/images/model-provider/openai.svg' case 'groq': return '/images/model-provider/groq.svg' case 'cohere': @@ -32,7 +30,7 @@ export function getProviderLogo(provider: string) { case 'deepseek': return '/images/model-provider/deepseek.svg' default: - return undefined + return '/images/model-provider/openai.svg' } } diff --git a/web-app/src/mock/data.ts b/web-app/src/mock/data.ts index f09cb9ce3..415052984 100644 --- a/web-app/src/mock/data.ts +++ b/web-app/src/mock/data.ts @@ -1,3 +1,29 @@ +export const openAIProviderSettings = [ + { + key: 'api-key', + title: 'API Key', + description: + "The OpenAI API uses API keys for authentication. Visit your [API Keys](https://platform.openai.com/account/api-keys) page to retrieve the API key you'll use in your requests.", + controller_type: 'input', + controller_props: { + placeholder: 'Insert API Key', + value: '', + type: 'password', + input_actions: ['unobscure', 'copy'], + }, + }, + { + key: 'base-url', + title: 'Base URL', + description: + 'The base endpoint to use. See the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/chat/create) for more information.', + controller_type: 'input', + controller_props: { + placeholder: 'https://api.openai.com/v1', + value: 'https://api.openai.com/v1', + }, + }, +] export const mockModelProvider = [ // { // active: true, diff --git a/web-app/src/routes/settings/providers/$providerName.tsx b/web-app/src/routes/settings/providers/$providerName.tsx index 3e8e8a0df..6290e803b 100644 --- a/web-app/src/routes/settings/providers/$providerName.tsx +++ b/web-app/src/routes/settings/providers/$providerName.tsx @@ -22,6 +22,8 @@ import { DialogDeleteModel } from '@/containers/dialogs/DeleteModel' import Joyride, { CallBackProps, STATUS } from 'react-joyride' import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide' import { route } from '@/constants/routes' +import DeleteProvider from '@/containers/dialogs/DeleteProvider' +import { updateSettings } from '@/services/providers' // as route.threadsDetail export const Route = createFileRoute('/settings/providers/$providerName')({ @@ -165,6 +167,7 @@ function ProviderDetail() { ) { updateObj.base_url = newValue } + updateSettings(providerName, updateObj.settings ?? []) updateProvider(providerName, { ...provider, ...updateObj, @@ -214,6 +217,8 @@ function ProviderDetail() { /> ) })} + + {/* Models */} diff --git a/web-app/src/services/providers.ts b/web-app/src/services/providers.ts index 4c3b9e7b7..abea1c3a0 100644 --- a/web-app/src/services/providers.ts +++ b/web-app/src/services/providers.ts @@ -1,9 +1,10 @@ import { models as providerModels } from 'token.js' import { mockModelProvider } from '@/mock/data' -import { EngineManager } from '@janhq/core' +import { EngineManager, SettingComponentProps } from '@janhq/core' import { ModelCapabilities } from '@/types/models' import { modelSettings } from '@/lib/predefined' import { fetchModels } from './models' +import { ExtensionManager } from '@/lib/extension' export const getProviders = async (): Promise => { const builtinProviders = mockModelProvider.map((provider) => { @@ -84,3 +85,28 @@ export const getProviders = async (): Promise => { return runtimeProviders.concat(builtinProviders as ModelProvider[]) } + +/** + * Update the settings of a provider extension. + * TODO: Later on we don't retrieve this using provider name + * @param providerName + * @param settings + */ +export const updateSettings = async ( + providerName: string, + settings: ProviderSetting[] +): Promise => { + const provider = providerName === 'llama.cpp' ? 'cortex' : providerName + return ExtensionManager.getInstance() + .getEngine(provider) + ?.updateSettings( + settings.map((setting) => ({ + ...setting, + controllerProps: { + ...setting.controller_props, + value: setting.controller_props.value ?? '', + }, + controllerType: setting.controller_type, + })) as SettingComponentProps[] + ) +}