feat: add custom OpenAI provider (#5033)

* feat: add custom OpenAI provider

* chore: add HF token setting

* chore: move HF token setting to llama.cpp provider - later deprecate model extension
This commit is contained in:
Louis 2025-05-20 14:30:51 +07:00 committed by GitHub
parent 46943a1cf7
commit d5393e4563
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 285 additions and 6 deletions

View File

@ -93,5 +93,17 @@
"controllerProps": { "controllerProps": {
"value": true "value": true
} }
},
{
"key": "hugging-face-access-token",
"title": "Hugging Face Access Token",
"description": "Access tokens programmatically authenticate your identity to the Hugging Face Hub, allowing applications to perform specific actions specified by the scope of permissions granted.",
"controllerType": "input",
"controllerProps": {
"value": "",
"placeholder": "hf_**********************************",
"type": "password",
"inputActions": ["unobscure", "copy"]
}
} }
] ]

View File

@ -36,6 +36,7 @@ enum Settings {
cache_type = 'cache_type', cache_type = 'cache_type',
use_mmap = 'use_mmap', use_mmap = 'use_mmap',
cpu_threads = 'cpu_threads', cpu_threads = 'cpu_threads',
huggingfaceToken = 'hugging-face-access-token',
} }
type LoadedModelResponse = { data: { engine: string; id: string }[] } type LoadedModelResponse = { data: { engine: string; id: string }[] }
@ -130,6 +131,13 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
) )
if (!Number.isNaN(threads_number)) this.cpu_threads = threads_number if (!Number.isNaN(threads_number)) this.cpu_threads = threads_number
const huggingfaceToken = await this.getSetting<string>(
Settings.huggingfaceToken,
''
)
if (huggingfaceToken) {
this.updateCortexConfig({ huggingface_token: huggingfaceToken })
}
this.subscribeToEvents() this.subscribeToEvents()
window.addEventListener('beforeunload', () => { window.addEventListener('beforeunload', () => {
@ -145,6 +153,11 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
super.onUnload() super.onUnload()
} }
/**
* Subscribe to settings update and make change accordingly
* @param key
* @param value
*/
onSettingUpdate<T>(key: string, value: T): void { onSettingUpdate<T>(key: string, value: T): void {
if (key === Settings.n_parallel && typeof value === 'string') { if (key === Settings.n_parallel && typeof value === 'string') {
this.n_parallel = Number(value) ?? 1 this.n_parallel = Number(value) ?? 1
@ -161,6 +174,8 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
} else if (key === Settings.cpu_threads && typeof value === 'string') { } else if (key === Settings.cpu_threads && typeof value === 'string') {
const threads_number = Number(value) const threads_number = Number(value)
if (!Number.isNaN(threads_number)) this.cpu_threads = threads_number if (!Number.isNaN(threads_number)) this.cpu_threads = threads_number
} else if (key === Settings.huggingfaceToken) {
this.updateCortexConfig({ huggingface_token: value })
} }
} }
@ -253,6 +268,18 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
}) })
} }
/**
* Update cortex config
* @param body
*/
private async updateCortexConfig(body: {
[key: string]: any
}): Promise<void> {
return this.apiInstance()
.then((api) => api.patch('v1/configs', { json: body }).then(() => {}))
.catch((e) => console.debug(e))
}
/** /**
* Subscribe to cortex.cpp websocket events * Subscribe to cortex.cpp websocket events
*/ */

View File

@ -2,16 +2,48 @@ import { route } from '@/constants/routes'
import { useModelProvider } from '@/hooks/useModelProvider' import { useModelProvider } from '@/hooks/useModelProvider'
import { cn, getProviderLogo, getProviderTitle } from '@/lib/utils' import { cn, getProviderLogo, getProviderTitle } from '@/lib/utils'
import { useNavigate, useMatches, Link } from '@tanstack/react-router' import { useNavigate, useMatches, Link } from '@tanstack/react-router'
import { IconArrowLeft } from '@tabler/icons-react' import { IconArrowLeft, IconCirclePlus } from '@tabler/icons-react'
import {
Dialog,
DialogClose,
DialogContent,
DialogFooter,
DialogHeader,
DialogTitle,
DialogTrigger,
} from '@/components/ui/dialog'
import { Input } from '@/components/ui/input'
import { Button } from '@/components/ui/button'
import { useCallback, useState } from 'react'
import { openAIProviderSettings } from '@/mock/data'
const ProvidersMenu = ({ const ProvidersMenu = ({
stepSetupRemoteProvider, stepSetupRemoteProvider,
}: { }: {
stepSetupRemoteProvider: boolean stepSetupRemoteProvider: boolean
}) => { }) => {
const { providers } = useModelProvider() const { providers, addProvider } = useModelProvider()
const navigate = useNavigate() const navigate = useNavigate()
const matches = useMatches() const matches = useMatches()
const [name, setName] = useState('')
const createProvider = useCallback(() => {
addProvider({
provider: name,
active: true,
models: [],
settings: openAIProviderSettings as ProviderSetting[],
api_key: '',
base_url: 'https://api.openai.com/v1',
})
setTimeout(() => {
navigate({
to: route.settings.providers,
params: {
providerName: name,
},
})
}, 0)
}, [name, addProvider, navigate])
return ( return (
<div className="w-44 py-2 border-r border-main-view-fg/5 pb-10 overflow-y-auto"> <div className="w-44 py-2 border-r border-main-view-fg/5 pb-10 overflow-y-auto">
@ -65,6 +97,51 @@ const ProvidersMenu = ({
</div> </div>
) )
})} })}
<Dialog>
<DialogTrigger asChild>
<div
className="bg-main-view flex cursor-pointer px-4 my-1.5 items-center gap-1.5 text-main-view-fg/80"
onClick={() => {}}
>
<IconCirclePlus size={16} />
<span className="capitalize">
Add Provider
</span>
</div>
</DialogTrigger>
<DialogContent>
<DialogHeader>
<DialogTitle>Add OpenAI Provider</DialogTitle>
<Input
value={name}
onChange={(e) => setName(e.target.value)}
className="mt-2"
placeholder="Enter a name for your provider"
onKeyDown={(e) => {
// Prevent key from being captured by parent components
e.stopPropagation()
}}
/>
<DialogFooter className="mt-2 flex items-center">
<DialogClose asChild>
<Button
variant="link"
size="sm"
className="hover:no-underline"
>
Cancel
</Button>
</DialogClose>
<DialogClose asChild>
<Button disabled={!name} onClick={createProvider}>
Create
</Button>
</DialogClose>
</DialogFooter>
</DialogHeader>
</DialogContent>
</Dialog>
</div> </div>
</div> </div>
) )

View File

@ -0,0 +1,93 @@
import { Button } from '@/components/ui/button'
import {
Dialog,
DialogClose,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
DialogTrigger,
} from '@/components/ui/dialog'
import { toast } from 'sonner'
import { CardItem } from '../Card'
import { models } from 'token.js'
import { EngineManager } from '@janhq/core'
import { useModelProvider } from '@/hooks/useModelProvider'
import { useRouter } from '@tanstack/react-router'
import { route } from '@/constants/routes'
import { normalizeProvider } from '@/lib/models'
type Props = {
provider?: ProviderObject
}
const DeleteProvider = ({ provider }: Props) => {
const { deleteProvider, providers } = useModelProvider()
const router = useRouter()
if (
!provider ||
Object.keys(models).includes(provider.provider) ||
EngineManager.instance().get(normalizeProvider(provider.provider))
)
return null
const removeProvider = async () => {
deleteProvider(provider.provider)
toast.success('Delete Provider', {
id: `delete-provider-${provider.provider}`,
description: `Provider ${provider.provider} has been permanently deleted.`,
})
setTimeout(() => {
router.navigate({
to: route.settings.providers,
params: {
providerName: providers[0].provider,
},
})
}, 0)
}
return (
<CardItem
title="Delete Provider"
description="Delete this provider and all its models. This action cannot be undone."
actions={
<Dialog>
<DialogTrigger asChild>
<Button variant="destructive" size="sm">
Delete
</Button>
</DialogTrigger>
<DialogContent>
<DialogHeader>
<DialogTitle>Delete Provider: {provider.provider}</DialogTitle>
<DialogDescription>
Are you sure you want to delete this provider? This action
cannot be undone.
</DialogDescription>
</DialogHeader>
<DialogFooter className="mt-2">
<DialogClose asChild>
<Button variant="link" size="sm" className="hover:no-underline">
Cancel
</Button>
</DialogClose>
<DialogClose asChild>
<Button
variant="destructive"
size="sm"
onClick={removeProvider}
>
Delete
</Button>
</DialogClose>
</DialogFooter>
</DialogContent>
</Dialog>
}
/>
)
}
export default DeleteProvider

View File

@ -37,6 +37,7 @@ export const useChat = () => {
const provider = useMemo(() => { const provider = useMemo(() => {
return getProviderByName(selectedProvider) return getProviderByName(selectedProvider)
}, [selectedProvider, getProviderByName]) }, [selectedProvider, getProviderByName])
const getCurrentThread = useCallback(async () => { const getCurrentThread = useCallback(async () => {
let currentThread = retrieveThread() let currentThread = retrieveThread()
if (!currentThread) { if (!currentThread) {

View File

@ -13,6 +13,8 @@ type ModelProviderState = {
providerName: string, providerName: string,
modelName: string modelName: string
) => Model | undefined ) => Model | undefined
addProvider: (provider: ModelProvider) => void
deleteProvider: (providerName: string) => void
deleteModel: (modelId: string) => void deleteModel: (modelId: string) => void
} }
@ -114,6 +116,18 @@ export const useModelProvider = create<ModelProviderState>()(
}), }),
})) }))
}, },
addProvider: (provider: ModelProvider) => {
set((state) => ({
providers: [...state.providers, provider],
}))
},
deleteProvider: (providerName: string) => {
set((state) => ({
providers: state.providers.filter(
(provider) => provider.provider !== providerName
),
}))
},
}), }),
{ {
name: localStorageKey.modelProvider, name: localStorageKey.modelProvider,

View File

@ -17,8 +17,6 @@ export function getProviderLogo(provider: string) {
return '/images/model-provider/martian.svg' return '/images/model-provider/martian.svg'
case 'openrouter': case 'openrouter':
return '/images/model-provider/openRouter.svg' return '/images/model-provider/openRouter.svg'
case 'openai':
return '/images/model-provider/openai.svg'
case 'groq': case 'groq':
return '/images/model-provider/groq.svg' return '/images/model-provider/groq.svg'
case 'cohere': case 'cohere':
@ -32,7 +30,7 @@ export function getProviderLogo(provider: string) {
case 'deepseek': case 'deepseek':
return '/images/model-provider/deepseek.svg' return '/images/model-provider/deepseek.svg'
default: default:
return undefined return '/images/model-provider/openai.svg'
} }
} }

View File

@ -1,3 +1,29 @@
export const openAIProviderSettings = [
{
key: 'api-key',
title: 'API Key',
description:
"The OpenAI API uses API keys for authentication. Visit your [API Keys](https://platform.openai.com/account/api-keys) page to retrieve the API key you'll use in your requests.",
controller_type: 'input',
controller_props: {
placeholder: 'Insert API Key',
value: '',
type: 'password',
input_actions: ['unobscure', 'copy'],
},
},
{
key: 'base-url',
title: 'Base URL',
description:
'The base endpoint to use. See the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/chat/create) for more information.',
controller_type: 'input',
controller_props: {
placeholder: 'https://api.openai.com/v1',
value: 'https://api.openai.com/v1',
},
},
]
export const mockModelProvider = [ export const mockModelProvider = [
// { // {
// active: true, // active: true,

View File

@ -22,6 +22,8 @@ import { DialogDeleteModel } from '@/containers/dialogs/DeleteModel'
import Joyride, { CallBackProps, STATUS } from 'react-joyride' import Joyride, { CallBackProps, STATUS } from 'react-joyride'
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide' import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
import { route } from '@/constants/routes' import { route } from '@/constants/routes'
import DeleteProvider from '@/containers/dialogs/DeleteProvider'
import { updateSettings } from '@/services/providers'
// as route.threadsDetail // as route.threadsDetail
export const Route = createFileRoute('/settings/providers/$providerName')({ export const Route = createFileRoute('/settings/providers/$providerName')({
@ -165,6 +167,7 @@ function ProviderDetail() {
) { ) {
updateObj.base_url = newValue updateObj.base_url = newValue
} }
updateSettings(providerName, updateObj.settings ?? [])
updateProvider(providerName, { updateProvider(providerName, {
...provider, ...provider,
...updateObj, ...updateObj,
@ -214,6 +217,8 @@ function ProviderDetail() {
/> />
) )
})} })}
<DeleteProvider provider={provider} />
</Card> </Card>
{/* Models */} {/* Models */}

View File

@ -1,9 +1,10 @@
import { models as providerModels } from 'token.js' import { models as providerModels } from 'token.js'
import { mockModelProvider } from '@/mock/data' import { mockModelProvider } from '@/mock/data'
import { EngineManager } from '@janhq/core' import { EngineManager, SettingComponentProps } from '@janhq/core'
import { ModelCapabilities } from '@/types/models' import { ModelCapabilities } from '@/types/models'
import { modelSettings } from '@/lib/predefined' import { modelSettings } from '@/lib/predefined'
import { fetchModels } from './models' import { fetchModels } from './models'
import { ExtensionManager } from '@/lib/extension'
export const getProviders = async (): Promise<ModelProvider[]> => { export const getProviders = async (): Promise<ModelProvider[]> => {
const builtinProviders = mockModelProvider.map((provider) => { const builtinProviders = mockModelProvider.map((provider) => {
@ -84,3 +85,28 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
return runtimeProviders.concat(builtinProviders as ModelProvider[]) return runtimeProviders.concat(builtinProviders as ModelProvider[])
} }
/**
* Update the settings of a provider extension.
* TODO: Later on we don't retrieve this using provider name
* @param providerName
* @param settings
*/
export const updateSettings = async (
providerName: string,
settings: ProviderSetting[]
): Promise<void> => {
const provider = providerName === 'llama.cpp' ? 'cortex' : providerName
return ExtensionManager.getInstance()
.getEngine(provider)
?.updateSettings(
settings.map((setting) => ({
...setting,
controllerProps: {
...setting.controller_props,
value: setting.controller_props.value ?? '',
},
controllerType: setting.controller_type,
})) as SettingComponentProps[]
)
}