feat: add custom OpenAI provider (#5033)
* feat: add custom OpenAI provider * chore: add HF token setting * chore: move HF token setting to llama.cpp provider - later deprecate model extension
This commit is contained in:
parent
46943a1cf7
commit
d5393e4563
@ -93,5 +93,17 @@
|
|||||||
"controllerProps": {
|
"controllerProps": {
|
||||||
"value": true
|
"value": true
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"key": "hugging-face-access-token",
|
||||||
|
"title": "Hugging Face Access Token",
|
||||||
|
"description": "Access tokens programmatically authenticate your identity to the Hugging Face Hub, allowing applications to perform specific actions specified by the scope of permissions granted.",
|
||||||
|
"controllerType": "input",
|
||||||
|
"controllerProps": {
|
||||||
|
"value": "",
|
||||||
|
"placeholder": "hf_**********************************",
|
||||||
|
"type": "password",
|
||||||
|
"inputActions": ["unobscure", "copy"]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@ -36,6 +36,7 @@ enum Settings {
|
|||||||
cache_type = 'cache_type',
|
cache_type = 'cache_type',
|
||||||
use_mmap = 'use_mmap',
|
use_mmap = 'use_mmap',
|
||||||
cpu_threads = 'cpu_threads',
|
cpu_threads = 'cpu_threads',
|
||||||
|
huggingfaceToken = 'hugging-face-access-token',
|
||||||
}
|
}
|
||||||
|
|
||||||
type LoadedModelResponse = { data: { engine: string; id: string }[] }
|
type LoadedModelResponse = { data: { engine: string; id: string }[] }
|
||||||
@ -130,6 +131,13 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
|||||||
)
|
)
|
||||||
if (!Number.isNaN(threads_number)) this.cpu_threads = threads_number
|
if (!Number.isNaN(threads_number)) this.cpu_threads = threads_number
|
||||||
|
|
||||||
|
const huggingfaceToken = await this.getSetting<string>(
|
||||||
|
Settings.huggingfaceToken,
|
||||||
|
''
|
||||||
|
)
|
||||||
|
if (huggingfaceToken) {
|
||||||
|
this.updateCortexConfig({ huggingface_token: huggingfaceToken })
|
||||||
|
}
|
||||||
this.subscribeToEvents()
|
this.subscribeToEvents()
|
||||||
|
|
||||||
window.addEventListener('beforeunload', () => {
|
window.addEventListener('beforeunload', () => {
|
||||||
@ -145,6 +153,11 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
|||||||
super.onUnload()
|
super.onUnload()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Subscribe to settings update and make change accordingly
|
||||||
|
* @param key
|
||||||
|
* @param value
|
||||||
|
*/
|
||||||
onSettingUpdate<T>(key: string, value: T): void {
|
onSettingUpdate<T>(key: string, value: T): void {
|
||||||
if (key === Settings.n_parallel && typeof value === 'string') {
|
if (key === Settings.n_parallel && typeof value === 'string') {
|
||||||
this.n_parallel = Number(value) ?? 1
|
this.n_parallel = Number(value) ?? 1
|
||||||
@ -161,6 +174,8 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
|||||||
} else if (key === Settings.cpu_threads && typeof value === 'string') {
|
} else if (key === Settings.cpu_threads && typeof value === 'string') {
|
||||||
const threads_number = Number(value)
|
const threads_number = Number(value)
|
||||||
if (!Number.isNaN(threads_number)) this.cpu_threads = threads_number
|
if (!Number.isNaN(threads_number)) this.cpu_threads = threads_number
|
||||||
|
} else if (key === Settings.huggingfaceToken) {
|
||||||
|
this.updateCortexConfig({ huggingface_token: value })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -253,6 +268,18 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update cortex config
|
||||||
|
* @param body
|
||||||
|
*/
|
||||||
|
private async updateCortexConfig(body: {
|
||||||
|
[key: string]: any
|
||||||
|
}): Promise<void> {
|
||||||
|
return this.apiInstance()
|
||||||
|
.then((api) => api.patch('v1/configs', { json: body }).then(() => {}))
|
||||||
|
.catch((e) => console.debug(e))
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subscribe to cortex.cpp websocket events
|
* Subscribe to cortex.cpp websocket events
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -2,16 +2,48 @@ import { route } from '@/constants/routes'
|
|||||||
import { useModelProvider } from '@/hooks/useModelProvider'
|
import { useModelProvider } from '@/hooks/useModelProvider'
|
||||||
import { cn, getProviderLogo, getProviderTitle } from '@/lib/utils'
|
import { cn, getProviderLogo, getProviderTitle } from '@/lib/utils'
|
||||||
import { useNavigate, useMatches, Link } from '@tanstack/react-router'
|
import { useNavigate, useMatches, Link } from '@tanstack/react-router'
|
||||||
import { IconArrowLeft } from '@tabler/icons-react'
|
import { IconArrowLeft, IconCirclePlus } from '@tabler/icons-react'
|
||||||
|
import {
|
||||||
|
Dialog,
|
||||||
|
DialogClose,
|
||||||
|
DialogContent,
|
||||||
|
DialogFooter,
|
||||||
|
DialogHeader,
|
||||||
|
DialogTitle,
|
||||||
|
DialogTrigger,
|
||||||
|
} from '@/components/ui/dialog'
|
||||||
|
import { Input } from '@/components/ui/input'
|
||||||
|
import { Button } from '@/components/ui/button'
|
||||||
|
import { useCallback, useState } from 'react'
|
||||||
|
import { openAIProviderSettings } from '@/mock/data'
|
||||||
|
|
||||||
const ProvidersMenu = ({
|
const ProvidersMenu = ({
|
||||||
stepSetupRemoteProvider,
|
stepSetupRemoteProvider,
|
||||||
}: {
|
}: {
|
||||||
stepSetupRemoteProvider: boolean
|
stepSetupRemoteProvider: boolean
|
||||||
}) => {
|
}) => {
|
||||||
const { providers } = useModelProvider()
|
const { providers, addProvider } = useModelProvider()
|
||||||
const navigate = useNavigate()
|
const navigate = useNavigate()
|
||||||
const matches = useMatches()
|
const matches = useMatches()
|
||||||
|
const [name, setName] = useState('')
|
||||||
|
const createProvider = useCallback(() => {
|
||||||
|
addProvider({
|
||||||
|
provider: name,
|
||||||
|
active: true,
|
||||||
|
models: [],
|
||||||
|
settings: openAIProviderSettings as ProviderSetting[],
|
||||||
|
api_key: '',
|
||||||
|
base_url: 'https://api.openai.com/v1',
|
||||||
|
})
|
||||||
|
setTimeout(() => {
|
||||||
|
navigate({
|
||||||
|
to: route.settings.providers,
|
||||||
|
params: {
|
||||||
|
providerName: name,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}, 0)
|
||||||
|
}, [name, addProvider, navigate])
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="w-44 py-2 border-r border-main-view-fg/5 pb-10 overflow-y-auto">
|
<div className="w-44 py-2 border-r border-main-view-fg/5 pb-10 overflow-y-auto">
|
||||||
@ -65,6 +97,51 @@ const ProvidersMenu = ({
|
|||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
})}
|
})}
|
||||||
|
|
||||||
|
<Dialog>
|
||||||
|
<DialogTrigger asChild>
|
||||||
|
<div
|
||||||
|
className="bg-main-view flex cursor-pointer px-4 my-1.5 items-center gap-1.5 text-main-view-fg/80"
|
||||||
|
onClick={() => {}}
|
||||||
|
>
|
||||||
|
<IconCirclePlus size={16} />
|
||||||
|
<span className="capitalize">
|
||||||
|
Add Provider
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</DialogTrigger>
|
||||||
|
<DialogContent>
|
||||||
|
<DialogHeader>
|
||||||
|
<DialogTitle>Add OpenAI Provider</DialogTitle>
|
||||||
|
<Input
|
||||||
|
value={name}
|
||||||
|
onChange={(e) => setName(e.target.value)}
|
||||||
|
className="mt-2"
|
||||||
|
placeholder="Enter a name for your provider"
|
||||||
|
onKeyDown={(e) => {
|
||||||
|
// Prevent key from being captured by parent components
|
||||||
|
e.stopPropagation()
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<DialogFooter className="mt-2 flex items-center">
|
||||||
|
<DialogClose asChild>
|
||||||
|
<Button
|
||||||
|
variant="link"
|
||||||
|
size="sm"
|
||||||
|
className="hover:no-underline"
|
||||||
|
>
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
</DialogClose>
|
||||||
|
<DialogClose asChild>
|
||||||
|
<Button disabled={!name} onClick={createProvider}>
|
||||||
|
Create
|
||||||
|
</Button>
|
||||||
|
</DialogClose>
|
||||||
|
</DialogFooter>
|
||||||
|
</DialogHeader>
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
|
|||||||
93
web-app/src/containers/dialogs/DeleteProvider.tsx
Normal file
93
web-app/src/containers/dialogs/DeleteProvider.tsx
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
import { Button } from '@/components/ui/button'
|
||||||
|
import {
|
||||||
|
Dialog,
|
||||||
|
DialogClose,
|
||||||
|
DialogContent,
|
||||||
|
DialogDescription,
|
||||||
|
DialogFooter,
|
||||||
|
DialogHeader,
|
||||||
|
DialogTitle,
|
||||||
|
DialogTrigger,
|
||||||
|
} from '@/components/ui/dialog'
|
||||||
|
|
||||||
|
import { toast } from 'sonner'
|
||||||
|
import { CardItem } from '../Card'
|
||||||
|
import { models } from 'token.js'
|
||||||
|
import { EngineManager } from '@janhq/core'
|
||||||
|
import { useModelProvider } from '@/hooks/useModelProvider'
|
||||||
|
import { useRouter } from '@tanstack/react-router'
|
||||||
|
import { route } from '@/constants/routes'
|
||||||
|
import { normalizeProvider } from '@/lib/models'
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
provider?: ProviderObject
|
||||||
|
}
|
||||||
|
const DeleteProvider = ({ provider }: Props) => {
|
||||||
|
const { deleteProvider, providers } = useModelProvider()
|
||||||
|
const router = useRouter()
|
||||||
|
if (
|
||||||
|
!provider ||
|
||||||
|
Object.keys(models).includes(provider.provider) ||
|
||||||
|
EngineManager.instance().get(normalizeProvider(provider.provider))
|
||||||
|
)
|
||||||
|
return null
|
||||||
|
|
||||||
|
const removeProvider = async () => {
|
||||||
|
deleteProvider(provider.provider)
|
||||||
|
toast.success('Delete Provider', {
|
||||||
|
id: `delete-provider-${provider.provider}`,
|
||||||
|
description: `Provider ${provider.provider} has been permanently deleted.`,
|
||||||
|
})
|
||||||
|
setTimeout(() => {
|
||||||
|
router.navigate({
|
||||||
|
to: route.settings.providers,
|
||||||
|
params: {
|
||||||
|
providerName: providers[0].provider,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<CardItem
|
||||||
|
title="Delete Provider"
|
||||||
|
description="Delete this provider and all its models. This action cannot be undone."
|
||||||
|
actions={
|
||||||
|
<Dialog>
|
||||||
|
<DialogTrigger asChild>
|
||||||
|
<Button variant="destructive" size="sm">
|
||||||
|
Delete
|
||||||
|
</Button>
|
||||||
|
</DialogTrigger>
|
||||||
|
<DialogContent>
|
||||||
|
<DialogHeader>
|
||||||
|
<DialogTitle>Delete Provider: {provider.provider}</DialogTitle>
|
||||||
|
<DialogDescription>
|
||||||
|
Are you sure you want to delete this provider? This action
|
||||||
|
cannot be undone.
|
||||||
|
</DialogDescription>
|
||||||
|
</DialogHeader>
|
||||||
|
|
||||||
|
<DialogFooter className="mt-2">
|
||||||
|
<DialogClose asChild>
|
||||||
|
<Button variant="link" size="sm" className="hover:no-underline">
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
</DialogClose>
|
||||||
|
<DialogClose asChild>
|
||||||
|
<Button
|
||||||
|
variant="destructive"
|
||||||
|
size="sm"
|
||||||
|
onClick={removeProvider}
|
||||||
|
>
|
||||||
|
Delete
|
||||||
|
</Button>
|
||||||
|
</DialogClose>
|
||||||
|
</DialogFooter>
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
export default DeleteProvider
|
||||||
@ -37,6 +37,7 @@ export const useChat = () => {
|
|||||||
const provider = useMemo(() => {
|
const provider = useMemo(() => {
|
||||||
return getProviderByName(selectedProvider)
|
return getProviderByName(selectedProvider)
|
||||||
}, [selectedProvider, getProviderByName])
|
}, [selectedProvider, getProviderByName])
|
||||||
|
|
||||||
const getCurrentThread = useCallback(async () => {
|
const getCurrentThread = useCallback(async () => {
|
||||||
let currentThread = retrieveThread()
|
let currentThread = retrieveThread()
|
||||||
if (!currentThread) {
|
if (!currentThread) {
|
||||||
|
|||||||
@ -13,6 +13,8 @@ type ModelProviderState = {
|
|||||||
providerName: string,
|
providerName: string,
|
||||||
modelName: string
|
modelName: string
|
||||||
) => Model | undefined
|
) => Model | undefined
|
||||||
|
addProvider: (provider: ModelProvider) => void
|
||||||
|
deleteProvider: (providerName: string) => void
|
||||||
deleteModel: (modelId: string) => void
|
deleteModel: (modelId: string) => void
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -114,6 +116,18 @@ export const useModelProvider = create<ModelProviderState>()(
|
|||||||
}),
|
}),
|
||||||
}))
|
}))
|
||||||
},
|
},
|
||||||
|
addProvider: (provider: ModelProvider) => {
|
||||||
|
set((state) => ({
|
||||||
|
providers: [...state.providers, provider],
|
||||||
|
}))
|
||||||
|
},
|
||||||
|
deleteProvider: (providerName: string) => {
|
||||||
|
set((state) => ({
|
||||||
|
providers: state.providers.filter(
|
||||||
|
(provider) => provider.provider !== providerName
|
||||||
|
),
|
||||||
|
}))
|
||||||
|
},
|
||||||
}),
|
}),
|
||||||
{
|
{
|
||||||
name: localStorageKey.modelProvider,
|
name: localStorageKey.modelProvider,
|
||||||
|
|||||||
@ -17,8 +17,6 @@ export function getProviderLogo(provider: string) {
|
|||||||
return '/images/model-provider/martian.svg'
|
return '/images/model-provider/martian.svg'
|
||||||
case 'openrouter':
|
case 'openrouter':
|
||||||
return '/images/model-provider/openRouter.svg'
|
return '/images/model-provider/openRouter.svg'
|
||||||
case 'openai':
|
|
||||||
return '/images/model-provider/openai.svg'
|
|
||||||
case 'groq':
|
case 'groq':
|
||||||
return '/images/model-provider/groq.svg'
|
return '/images/model-provider/groq.svg'
|
||||||
case 'cohere':
|
case 'cohere':
|
||||||
@ -32,7 +30,7 @@ export function getProviderLogo(provider: string) {
|
|||||||
case 'deepseek':
|
case 'deepseek':
|
||||||
return '/images/model-provider/deepseek.svg'
|
return '/images/model-provider/deepseek.svg'
|
||||||
default:
|
default:
|
||||||
return undefined
|
return '/images/model-provider/openai.svg'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,29 @@
|
|||||||
|
export const openAIProviderSettings = [
|
||||||
|
{
|
||||||
|
key: 'api-key',
|
||||||
|
title: 'API Key',
|
||||||
|
description:
|
||||||
|
"The OpenAI API uses API keys for authentication. Visit your [API Keys](https://platform.openai.com/account/api-keys) page to retrieve the API key you'll use in your requests.",
|
||||||
|
controller_type: 'input',
|
||||||
|
controller_props: {
|
||||||
|
placeholder: 'Insert API Key',
|
||||||
|
value: '',
|
||||||
|
type: 'password',
|
||||||
|
input_actions: ['unobscure', 'copy'],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: 'base-url',
|
||||||
|
title: 'Base URL',
|
||||||
|
description:
|
||||||
|
'The base endpoint to use. See the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/chat/create) for more information.',
|
||||||
|
controller_type: 'input',
|
||||||
|
controller_props: {
|
||||||
|
placeholder: 'https://api.openai.com/v1',
|
||||||
|
value: 'https://api.openai.com/v1',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
export const mockModelProvider = [
|
export const mockModelProvider = [
|
||||||
// {
|
// {
|
||||||
// active: true,
|
// active: true,
|
||||||
|
|||||||
@ -22,6 +22,8 @@ import { DialogDeleteModel } from '@/containers/dialogs/DeleteModel'
|
|||||||
import Joyride, { CallBackProps, STATUS } from 'react-joyride'
|
import Joyride, { CallBackProps, STATUS } from 'react-joyride'
|
||||||
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
|
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
|
||||||
import { route } from '@/constants/routes'
|
import { route } from '@/constants/routes'
|
||||||
|
import DeleteProvider from '@/containers/dialogs/DeleteProvider'
|
||||||
|
import { updateSettings } from '@/services/providers'
|
||||||
|
|
||||||
// as route.threadsDetail
|
// as route.threadsDetail
|
||||||
export const Route = createFileRoute('/settings/providers/$providerName')({
|
export const Route = createFileRoute('/settings/providers/$providerName')({
|
||||||
@ -165,6 +167,7 @@ function ProviderDetail() {
|
|||||||
) {
|
) {
|
||||||
updateObj.base_url = newValue
|
updateObj.base_url = newValue
|
||||||
}
|
}
|
||||||
|
updateSettings(providerName, updateObj.settings ?? [])
|
||||||
updateProvider(providerName, {
|
updateProvider(providerName, {
|
||||||
...provider,
|
...provider,
|
||||||
...updateObj,
|
...updateObj,
|
||||||
@ -214,6 +217,8 @@ function ProviderDetail() {
|
|||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
})}
|
})}
|
||||||
|
|
||||||
|
<DeleteProvider provider={provider} />
|
||||||
</Card>
|
</Card>
|
||||||
|
|
||||||
{/* Models */}
|
{/* Models */}
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
import { models as providerModels } from 'token.js'
|
import { models as providerModels } from 'token.js'
|
||||||
import { mockModelProvider } from '@/mock/data'
|
import { mockModelProvider } from '@/mock/data'
|
||||||
import { EngineManager } from '@janhq/core'
|
import { EngineManager, SettingComponentProps } from '@janhq/core'
|
||||||
import { ModelCapabilities } from '@/types/models'
|
import { ModelCapabilities } from '@/types/models'
|
||||||
import { modelSettings } from '@/lib/predefined'
|
import { modelSettings } from '@/lib/predefined'
|
||||||
import { fetchModels } from './models'
|
import { fetchModels } from './models'
|
||||||
|
import { ExtensionManager } from '@/lib/extension'
|
||||||
|
|
||||||
export const getProviders = async (): Promise<ModelProvider[]> => {
|
export const getProviders = async (): Promise<ModelProvider[]> => {
|
||||||
const builtinProviders = mockModelProvider.map((provider) => {
|
const builtinProviders = mockModelProvider.map((provider) => {
|
||||||
@ -84,3 +85,28 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
|
|||||||
|
|
||||||
return runtimeProviders.concat(builtinProviders as ModelProvider[])
|
return runtimeProviders.concat(builtinProviders as ModelProvider[])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update the settings of a provider extension.
|
||||||
|
* TODO: Later on we don't retrieve this using provider name
|
||||||
|
* @param providerName
|
||||||
|
* @param settings
|
||||||
|
*/
|
||||||
|
export const updateSettings = async (
|
||||||
|
providerName: string,
|
||||||
|
settings: ProviderSetting[]
|
||||||
|
): Promise<void> => {
|
||||||
|
const provider = providerName === 'llama.cpp' ? 'cortex' : providerName
|
||||||
|
return ExtensionManager.getInstance()
|
||||||
|
.getEngine(provider)
|
||||||
|
?.updateSettings(
|
||||||
|
settings.map((setting) => ({
|
||||||
|
...setting,
|
||||||
|
controllerProps: {
|
||||||
|
...setting.controller_props,
|
||||||
|
value: setting.controller_props.value ?? '',
|
||||||
|
},
|
||||||
|
controllerType: setting.controller_type,
|
||||||
|
})) as SettingComponentProps[]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user