chore: default context length to 2048 (#2746)

This commit is contained in:
NamH 2024-04-17 19:14:51 +07:00 committed by GitHub
parent a2cb1353cd
commit 95632788e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 48 additions and 34 deletions

View File

@ -32,4 +32,5 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter
abstract importModels(models: ImportingModel[], optionType: OptionType): Promise<void> abstract importModels(models: ImportingModel[], optionType: OptionType): Promise<void>
abstract updateModelInfo(modelInfo: Partial<Model>): Promise<Model> abstract updateModelInfo(modelInfo: Partial<Model>): Promise<Model>
abstract fetchHuggingFaceRepoData(repoId: string): Promise<HuggingFaceRepoData> abstract fetchHuggingFaceRepoData(repoId: string): Promise<HuggingFaceRepoData>
abstract getDefaultModel(): Promise<Model>
} }

View File

@ -27,7 +27,7 @@
"min": 0, "min": 0,
"max": 4096, "max": 4096,
"step": 128, "step": 128,
"value": 4096 "value": 2048
} }
} }
] ]

View File

@ -13,7 +13,7 @@
"created": 0, "created": 0,
"description": "User self import model", "description": "User self import model",
"settings": { "settings": {
"ctx_len": 4096, "ctx_len": 2048,
"embedding": false, "embedding": false,
"prompt_template": "{system_message}\n### Instruction: {prompt}\n### Response:", "prompt_template": "{system_message}\n### Instruction: {prompt}\n### Response:",
"llama_model_path": "N/A" "llama_model_path": "N/A"

View File

@ -551,7 +551,7 @@ export default class JanModelExtension extends ModelExtension {
return model return model
} }
private async getDefaultModel(): Promise<Model> { override async getDefaultModel(): Promise<Model> {
const defaultModel = DEFAULT_MODEL as Model const defaultModel = DEFAULT_MODEL as Model
return defaultModel return defaultModel
} }

View File

@ -46,6 +46,8 @@ export const removeDownloadedModelAtom = atom(
export const configuredModelsAtom = atom<Model[]>([]) export const configuredModelsAtom = atom<Model[]>([])
export const defaultModelAtom = atom<Model | undefined>(undefined)
/// TODO: move this part to another atom /// TODO: move this part to another atom
// store the paths of the models that are being imported // store the paths of the models that are being imported
export const importingModelsAtom = atom<ImportingModel[]>([]) export const importingModelsAtom = atom<ImportingModel[]>([])

View File

@ -13,25 +13,37 @@ import { useSetAtom } from 'jotai'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import { import {
configuredModelsAtom, configuredModelsAtom,
defaultModelAtom,
downloadedModelsAtom, downloadedModelsAtom,
} from '@/helpers/atoms/Model.atom' } from '@/helpers/atoms/Model.atom'
const useModels = () => { const useModels = () => {
const setDownloadedModels = useSetAtom(downloadedModelsAtom) const setDownloadedModels = useSetAtom(downloadedModelsAtom)
const setConfiguredModels = useSetAtom(configuredModelsAtom) const setConfiguredModels = useSetAtom(configuredModelsAtom)
const setDefaultModel = useSetAtom(defaultModelAtom)
const getData = useCallback(() => { const getData = useCallback(() => {
const getDownloadedModels = async () => { const getDownloadedModels = async () => {
const models = await getLocalDownloadedModels() const models = await getLocalDownloadedModels()
setDownloadedModels(models) setDownloadedModels(models)
} }
const getConfiguredModels = async () => { const getConfiguredModels = async () => {
const models = await getLocalConfiguredModels() const models = await getLocalConfiguredModels()
setConfiguredModels(models) setConfiguredModels(models)
} }
getDownloadedModels()
getConfiguredModels() const getDefaultModel = async () => {
}, [setDownloadedModels, setConfiguredModels]) const defaultModel = await getLocalDefaultModel()
setDefaultModel(defaultModel)
}
Promise.all([
getDownloadedModels(),
getConfiguredModels(),
getDefaultModel(),
])
}, [setDownloadedModels, setConfiguredModels, setDefaultModel])
useEffect(() => { useEffect(() => {
// Try get data on mount // Try get data on mount
@ -46,6 +58,11 @@ const useModels = () => {
}, [getData]) }, [getData])
} }
/**
 * Asks the registered Model extension for its default model.
 *
 * @returns The model returned by the extension's `getDefaultModel()`, or
 * `undefined` when the extension manager has no Model extension to offer.
 */
const getLocalDefaultModel = async (): Promise<Model | undefined> => {
  const modelExtension = extensionManager.get<ModelExtension>(
    ExtensionTypeEnum.Model
  )
  // Optional chaining keeps the "extension missing" case a plain `undefined`.
  return modelExtension?.getDefaultModel()
}
const getLocalConfiguredModels = async (): Promise<Model[]> => const getLocalConfiguredModels = async (): Promise<Model[]> =>
extensionManager extensionManager
.get<ModelExtension>(ExtensionTypeEnum.Model) .get<ModelExtension>(ExtensionTypeEnum.Model)

View File

@ -3,7 +3,6 @@ import { useCallback, useMemo } from 'react'
import { import {
DownloadState, DownloadState,
HuggingFaceRepoData, HuggingFaceRepoData,
InferenceEngine,
Model, Model,
Quantization, Quantization,
} from '@janhq/core' } from '@janhq/core'
@ -23,7 +22,10 @@ import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
import { assistantsAtom } from '@/helpers/atoms/Assistant.atom' import { assistantsAtom } from '@/helpers/atoms/Assistant.atom'
import { importHuggingFaceModelStageAtom } from '@/helpers/atoms/HuggingFace.atom' import { importHuggingFaceModelStageAtom } from '@/helpers/atoms/HuggingFace.atom'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' import {
defaultModelAtom,
downloadedModelsAtom,
} from '@/helpers/atoms/Model.atom'
type Props = { type Props = {
index: number index: number
@ -52,15 +54,15 @@ const ModelDownloadRow: React.FC<Props> = ({
const isDownloaded = downloadedModels.find((md) => md.id === fileName) != null const isDownloaded = downloadedModels.find((md) => md.id === fileName) != null
const setHfImportingStage = useSetAtom(importHuggingFaceModelStageAtom) const setHfImportingStage = useSetAtom(importHuggingFaceModelStageAtom)
const defaultModel = useAtomValue(defaultModelAtom)
const model = useMemo(() => { const model = useMemo(() => {
const promptData: string = if (!defaultModel) {
(repoData.cardData['prompt_template'] as string) ?? return undefined
'{system_message}\n### Instruction: {prompt}\n### Response:' }
const model: Model = { const model: Model = {
object: 'model', ...defaultModel,
version: '1.0',
format: 'gguf',
sources: [ sources: [
{ {
url: downloadUrl, url: downloadUrl,
@ -70,38 +72,26 @@ const ModelDownloadRow: React.FC<Props> = ({
id: fileName, id: fileName,
name: fileName, name: fileName,
created: Date.now(), created: Date.now(),
description: 'User self import model',
settings: {
ctx_len: 4096,
embedding: false,
prompt_template: promptData,
llama_model_path: 'N/A',
},
parameters: {
temperature: 0.7,
top_p: 0.95,
stream: true,
max_tokens: 2048,
stop: ['<endofstring>'],
frequency_penalty: 0.7,
presence_penalty: 0,
},
metadata: { metadata: {
author: 'User', author: 'User',
tags: repoData.tags, tags: repoData.tags,
size: fileSize, size: fileSize,
}, },
engine: InferenceEngine.nitro,
} }
console.log('NamH model: ', JSON.stringify(model))
return model return model
}, [fileName, fileSize, repoData, downloadUrl]) }, [fileName, fileSize, repoData, downloadUrl, defaultModel])
const onAbortDownloadClick = useCallback(() => { const onAbortDownloadClick = useCallback(() => {
if (model) {
abortModelDownload(model) abortModelDownload(model)
}
}, [model, abortModelDownload]) }, [model, abortModelDownload])
const onDownloadClick = useCallback(async () => { const onDownloadClick = useCallback(async () => {
if (model) {
downloadModel(model) downloadModel(model)
}
}, [model, downloadModel]) }, [model, downloadModel])
const onUseModelClick = useCallback(async () => { const onUseModelClick = useCallback(async () => {
@ -120,6 +110,10 @@ const ModelDownloadRow: React.FC<Props> = ({
setHfImportingStage, setHfImportingStage,
]) ])
if (!model) {
return null
}
return ( return (
<div className="flex w-[662px] flex-row items-center justify-between space-x-1 rounded border border-border p-3"> <div className="flex w-[662px] flex-row items-center justify-between space-x-1 rounded border border-border p-3">
<div className="flex"> <div className="flex">