feat: add custom OpenAI provider (#5033)
* feat: add custom OpenAI provider * chore: add HF token setting * chore: move HF token setting to llama.cpp provider - later deprecate model extension
This commit is contained in:
parent
46943a1cf7
commit
d5393e4563
@ -93,5 +93,17 @@
|
||||
"controllerProps": {
|
||||
"value": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"key": "hugging-face-access-token",
|
||||
"title": "Hugging Face Access Token",
|
||||
"description": "Access tokens programmatically authenticate your identity to the Hugging Face Hub, allowing applications to perform specific actions specified by the scope of permissions granted.",
|
||||
"controllerType": "input",
|
||||
"controllerProps": {
|
||||
"value": "",
|
||||
"placeholder": "hf_**********************************",
|
||||
"type": "password",
|
||||
"inputActions": ["unobscure", "copy"]
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@ -36,6 +36,7 @@ enum Settings {
|
||||
cache_type = 'cache_type',
|
||||
use_mmap = 'use_mmap',
|
||||
cpu_threads = 'cpu_threads',
|
||||
huggingfaceToken = 'hugging-face-access-token',
|
||||
}
|
||||
|
||||
type LoadedModelResponse = { data: { engine: string; id: string }[] }
|
||||
@ -130,6 +131,13 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
||||
)
|
||||
if (!Number.isNaN(threads_number)) this.cpu_threads = threads_number
|
||||
|
||||
const huggingfaceToken = await this.getSetting<string>(
|
||||
Settings.huggingfaceToken,
|
||||
''
|
||||
)
|
||||
if (huggingfaceToken) {
|
||||
this.updateCortexConfig({ huggingface_token: huggingfaceToken })
|
||||
}
|
||||
this.subscribeToEvents()
|
||||
|
||||
window.addEventListener('beforeunload', () => {
|
||||
@ -145,6 +153,11 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
||||
super.onUnload()
|
||||
}
|
||||
|
||||
/**
|
||||
* Subscribe to settings update and make change accordingly
|
||||
* @param key
|
||||
* @param value
|
||||
*/
|
||||
onSettingUpdate<T>(key: string, value: T): void {
|
||||
if (key === Settings.n_parallel && typeof value === 'string') {
|
||||
this.n_parallel = Number(value) ?? 1
|
||||
@ -161,6 +174,8 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
||||
} else if (key === Settings.cpu_threads && typeof value === 'string') {
|
||||
const threads_number = Number(value)
|
||||
if (!Number.isNaN(threads_number)) this.cpu_threads = threads_number
|
||||
} else if (key === Settings.huggingfaceToken) {
|
||||
this.updateCortexConfig({ huggingface_token: value })
|
||||
}
|
||||
}
|
||||
|
||||
@ -253,6 +268,18 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Update cortex config
|
||||
* @param body
|
||||
*/
|
||||
private async updateCortexConfig(body: {
|
||||
[key: string]: any
|
||||
}): Promise<void> {
|
||||
return this.apiInstance()
|
||||
.then((api) => api.patch('v1/configs', { json: body }).then(() => {}))
|
||||
.catch((e) => console.debug(e))
|
||||
}
|
||||
|
||||
/**
|
||||
* Subscribe to cortex.cpp websocket events
|
||||
*/
|
||||
|
||||
@ -2,16 +2,48 @@ import { route } from '@/constants/routes'
|
||||
import { useModelProvider } from '@/hooks/useModelProvider'
|
||||
import { cn, getProviderLogo, getProviderTitle } from '@/lib/utils'
|
||||
import { useNavigate, useMatches, Link } from '@tanstack/react-router'
|
||||
import { IconArrowLeft } from '@tabler/icons-react'
|
||||
import { IconArrowLeft, IconCirclePlus } from '@tabler/icons-react'
|
||||
import {
|
||||
Dialog,
|
||||
DialogClose,
|
||||
DialogContent,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { useCallback, useState } from 'react'
|
||||
import { openAIProviderSettings } from '@/mock/data'
|
||||
|
||||
const ProvidersMenu = ({
|
||||
stepSetupRemoteProvider,
|
||||
}: {
|
||||
stepSetupRemoteProvider: boolean
|
||||
}) => {
|
||||
const { providers } = useModelProvider()
|
||||
const { providers, addProvider } = useModelProvider()
|
||||
const navigate = useNavigate()
|
||||
const matches = useMatches()
|
||||
const [name, setName] = useState('')
|
||||
const createProvider = useCallback(() => {
|
||||
addProvider({
|
||||
provider: name,
|
||||
active: true,
|
||||
models: [],
|
||||
settings: openAIProviderSettings as ProviderSetting[],
|
||||
api_key: '',
|
||||
base_url: 'https://api.openai.com/v1',
|
||||
})
|
||||
setTimeout(() => {
|
||||
navigate({
|
||||
to: route.settings.providers,
|
||||
params: {
|
||||
providerName: name,
|
||||
},
|
||||
})
|
||||
}, 0)
|
||||
}, [name, addProvider, navigate])
|
||||
|
||||
return (
|
||||
<div className="w-44 py-2 border-r border-main-view-fg/5 pb-10 overflow-y-auto">
|
||||
@ -65,6 +97,51 @@ const ProvidersMenu = ({
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
|
||||
<Dialog>
|
||||
<DialogTrigger asChild>
|
||||
<div
|
||||
className="bg-main-view flex cursor-pointer px-4 my-1.5 items-center gap-1.5 text-main-view-fg/80"
|
||||
onClick={() => {}}
|
||||
>
|
||||
<IconCirclePlus size={16} />
|
||||
<span className="capitalize">
|
||||
Add Provider
|
||||
</span>
|
||||
</div>
|
||||
</DialogTrigger>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Add OpenAI Provider</DialogTitle>
|
||||
<Input
|
||||
value={name}
|
||||
onChange={(e) => setName(e.target.value)}
|
||||
className="mt-2"
|
||||
placeholder="Enter a name for your provider"
|
||||
onKeyDown={(e) => {
|
||||
// Prevent key from being captured by parent components
|
||||
e.stopPropagation()
|
||||
}}
|
||||
/>
|
||||
<DialogFooter className="mt-2 flex items-center">
|
||||
<DialogClose asChild>
|
||||
<Button
|
||||
variant="link"
|
||||
size="sm"
|
||||
className="hover:no-underline"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
</DialogClose>
|
||||
<DialogClose asChild>
|
||||
<Button disabled={!name} onClick={createProvider}>
|
||||
Create
|
||||
</Button>
|
||||
</DialogClose>
|
||||
</DialogFooter>
|
||||
</DialogHeader>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
|
||||
93
web-app/src/containers/dialogs/DeleteProvider.tsx
Normal file
93
web-app/src/containers/dialogs/DeleteProvider.tsx
Normal file
@ -0,0 +1,93 @@
|
||||
import { Button } from '@/components/ui/button'
|
||||
import {
|
||||
Dialog,
|
||||
DialogClose,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from '@/components/ui/dialog'
|
||||
|
||||
import { toast } from 'sonner'
|
||||
import { CardItem } from '../Card'
|
||||
import { models } from 'token.js'
|
||||
import { EngineManager } from '@janhq/core'
|
||||
import { useModelProvider } from '@/hooks/useModelProvider'
|
||||
import { useRouter } from '@tanstack/react-router'
|
||||
import { route } from '@/constants/routes'
|
||||
import { normalizeProvider } from '@/lib/models'
|
||||
|
||||
type Props = {
|
||||
provider?: ProviderObject
|
||||
}
|
||||
const DeleteProvider = ({ provider }: Props) => {
|
||||
const { deleteProvider, providers } = useModelProvider()
|
||||
const router = useRouter()
|
||||
if (
|
||||
!provider ||
|
||||
Object.keys(models).includes(provider.provider) ||
|
||||
EngineManager.instance().get(normalizeProvider(provider.provider))
|
||||
)
|
||||
return null
|
||||
|
||||
const removeProvider = async () => {
|
||||
deleteProvider(provider.provider)
|
||||
toast.success('Delete Provider', {
|
||||
id: `delete-provider-${provider.provider}`,
|
||||
description: `Provider ${provider.provider} has been permanently deleted.`,
|
||||
})
|
||||
setTimeout(() => {
|
||||
router.navigate({
|
||||
to: route.settings.providers,
|
||||
params: {
|
||||
providerName: providers[0].provider,
|
||||
},
|
||||
})
|
||||
}, 0)
|
||||
}
|
||||
|
||||
return (
|
||||
<CardItem
|
||||
title="Delete Provider"
|
||||
description="Delete this provider and all its models. This action cannot be undone."
|
||||
actions={
|
||||
<Dialog>
|
||||
<DialogTrigger asChild>
|
||||
<Button variant="destructive" size="sm">
|
||||
Delete
|
||||
</Button>
|
||||
</DialogTrigger>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Delete Provider: {provider.provider}</DialogTitle>
|
||||
<DialogDescription>
|
||||
Are you sure you want to delete this provider? This action
|
||||
cannot be undone.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<DialogFooter className="mt-2">
|
||||
<DialogClose asChild>
|
||||
<Button variant="link" size="sm" className="hover:no-underline">
|
||||
Cancel
|
||||
</Button>
|
||||
</DialogClose>
|
||||
<DialogClose asChild>
|
||||
<Button
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
onClick={removeProvider}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</DialogClose>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
}
|
||||
/>
|
||||
)
|
||||
}
|
||||
export default DeleteProvider
|
||||
@ -37,6 +37,7 @@ export const useChat = () => {
|
||||
const provider = useMemo(() => {
|
||||
return getProviderByName(selectedProvider)
|
||||
}, [selectedProvider, getProviderByName])
|
||||
|
||||
const getCurrentThread = useCallback(async () => {
|
||||
let currentThread = retrieveThread()
|
||||
if (!currentThread) {
|
||||
|
||||
@ -13,6 +13,8 @@ type ModelProviderState = {
|
||||
providerName: string,
|
||||
modelName: string
|
||||
) => Model | undefined
|
||||
addProvider: (provider: ModelProvider) => void
|
||||
deleteProvider: (providerName: string) => void
|
||||
deleteModel: (modelId: string) => void
|
||||
}
|
||||
|
||||
@ -114,6 +116,18 @@ export const useModelProvider = create<ModelProviderState>()(
|
||||
}),
|
||||
}))
|
||||
},
|
||||
addProvider: (provider: ModelProvider) => {
|
||||
set((state) => ({
|
||||
providers: [...state.providers, provider],
|
||||
}))
|
||||
},
|
||||
deleteProvider: (providerName: string) => {
|
||||
set((state) => ({
|
||||
providers: state.providers.filter(
|
||||
(provider) => provider.provider !== providerName
|
||||
),
|
||||
}))
|
||||
},
|
||||
}),
|
||||
{
|
||||
name: localStorageKey.modelProvider,
|
||||
|
||||
@ -17,8 +17,6 @@ export function getProviderLogo(provider: string) {
|
||||
return '/images/model-provider/martian.svg'
|
||||
case 'openrouter':
|
||||
return '/images/model-provider/openRouter.svg'
|
||||
case 'openai':
|
||||
return '/images/model-provider/openai.svg'
|
||||
case 'groq':
|
||||
return '/images/model-provider/groq.svg'
|
||||
case 'cohere':
|
||||
@ -32,7 +30,7 @@ export function getProviderLogo(provider: string) {
|
||||
case 'deepseek':
|
||||
return '/images/model-provider/deepseek.svg'
|
||||
default:
|
||||
return undefined
|
||||
return '/images/model-provider/openai.svg'
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,3 +1,29 @@
|
||||
export const openAIProviderSettings = [
|
||||
{
|
||||
key: 'api-key',
|
||||
title: 'API Key',
|
||||
description:
|
||||
"The OpenAI API uses API keys for authentication. Visit your [API Keys](https://platform.openai.com/account/api-keys) page to retrieve the API key you'll use in your requests.",
|
||||
controller_type: 'input',
|
||||
controller_props: {
|
||||
placeholder: 'Insert API Key',
|
||||
value: '',
|
||||
type: 'password',
|
||||
input_actions: ['unobscure', 'copy'],
|
||||
},
|
||||
},
|
||||
{
|
||||
key: 'base-url',
|
||||
title: 'Base URL',
|
||||
description:
|
||||
'The base endpoint to use. See the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/chat/create) for more information.',
|
||||
controller_type: 'input',
|
||||
controller_props: {
|
||||
placeholder: 'https://api.openai.com/v1',
|
||||
value: 'https://api.openai.com/v1',
|
||||
},
|
||||
},
|
||||
]
|
||||
export const mockModelProvider = [
|
||||
// {
|
||||
// active: true,
|
||||
|
||||
@ -22,6 +22,8 @@ import { DialogDeleteModel } from '@/containers/dialogs/DeleteModel'
|
||||
import Joyride, { CallBackProps, STATUS } from 'react-joyride'
|
||||
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
|
||||
import { route } from '@/constants/routes'
|
||||
import DeleteProvider from '@/containers/dialogs/DeleteProvider'
|
||||
import { updateSettings } from '@/services/providers'
|
||||
|
||||
// as route.threadsDetail
|
||||
export const Route = createFileRoute('/settings/providers/$providerName')({
|
||||
@ -165,6 +167,7 @@ function ProviderDetail() {
|
||||
) {
|
||||
updateObj.base_url = newValue
|
||||
}
|
||||
updateSettings(providerName, updateObj.settings ?? [])
|
||||
updateProvider(providerName, {
|
||||
...provider,
|
||||
...updateObj,
|
||||
@ -214,6 +217,8 @@ function ProviderDetail() {
|
||||
/>
|
||||
)
|
||||
})}
|
||||
|
||||
<DeleteProvider provider={provider} />
|
||||
</Card>
|
||||
|
||||
{/* Models */}
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import { models as providerModels } from 'token.js'
|
||||
import { mockModelProvider } from '@/mock/data'
|
||||
import { EngineManager } from '@janhq/core'
|
||||
import { EngineManager, SettingComponentProps } from '@janhq/core'
|
||||
import { ModelCapabilities } from '@/types/models'
|
||||
import { modelSettings } from '@/lib/predefined'
|
||||
import { fetchModels } from './models'
|
||||
import { ExtensionManager } from '@/lib/extension'
|
||||
|
||||
export const getProviders = async (): Promise<ModelProvider[]> => {
|
||||
const builtinProviders = mockModelProvider.map((provider) => {
|
||||
@ -84,3 +85,28 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
|
||||
|
||||
return runtimeProviders.concat(builtinProviders as ModelProvider[])
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the settings of a provider extension.
|
||||
* TODO: Later on we don't retrieve this using provider name
|
||||
* @param providerName
|
||||
* @param settings
|
||||
*/
|
||||
export const updateSettings = async (
|
||||
providerName: string,
|
||||
settings: ProviderSetting[]
|
||||
): Promise<void> => {
|
||||
const provider = providerName === 'llama.cpp' ? 'cortex' : providerName
|
||||
return ExtensionManager.getInstance()
|
||||
.getEngine(provider)
|
||||
?.updateSettings(
|
||||
settings.map((setting) => ({
|
||||
...setting,
|
||||
controllerProps: {
|
||||
...setting.controller_props,
|
||||
value: setting.controller_props.value ?? '',
|
||||
},
|
||||
controllerType: setting.controller_type,
|
||||
})) as SettingComponentProps[]
|
||||
)
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user