feat: reroute threads and messages requests to the backend
This commit is contained in:
parent
14737b7e31
commit
174f1c7dcb
@ -1,4 +1,10 @@
|
||||
import { Thread, ThreadInterface, ThreadMessage, MessageInterface } from '../../types'
|
||||
import {
|
||||
Thread,
|
||||
ThreadInterface,
|
||||
ThreadMessage,
|
||||
MessageInterface,
|
||||
ThreadAssistantInfo,
|
||||
} from '../../types'
|
||||
import { BaseExtension, ExtensionTypeEnum } from '../extension'
|
||||
|
||||
/**
|
||||
@ -17,10 +23,21 @@ export abstract class ConversationalExtension
|
||||
return ExtensionTypeEnum.Conversational
|
||||
}
|
||||
|
||||
abstract getThreads(): Promise<Thread[]>
|
||||
abstract saveThread(thread: Thread): Promise<void>
|
||||
abstract listThreads(): Promise<Thread[]>
|
||||
abstract createThread(thread: Partial<Thread>): Promise<Thread>
|
||||
abstract modifyThread(thread: Thread): Promise<void>
|
||||
abstract deleteThread(threadId: string): Promise<void>
|
||||
abstract addNewMessage(message: ThreadMessage): Promise<void>
|
||||
abstract writeMessages(threadId: string, messages: ThreadMessage[]): Promise<void>
|
||||
abstract getAllMessages(threadId: string): Promise<ThreadMessage[]>
|
||||
abstract createMessage(message: Partial<ThreadMessage>): Promise<ThreadMessage>
|
||||
abstract deleteMessage(threadId: string, messageId: string): Promise<void>
|
||||
abstract listMessages(threadId: string): Promise<ThreadMessage[]>
|
||||
abstract getThreadAssistant(threadId: string): Promise<ThreadAssistantInfo>
|
||||
abstract createThreadAssistant(
|
||||
threadId: string,
|
||||
assistant: ThreadAssistantInfo
|
||||
): Promise<ThreadAssistantInfo>
|
||||
abstract modifyThreadAssistant(
|
||||
threadId: string,
|
||||
assistant: ThreadAssistantInfo
|
||||
): Promise<ThreadAssistantInfo>
|
||||
abstract modifyMessage(message: ThreadMessage): Promise<ThreadMessage>
|
||||
}
|
||||
|
||||
@ -2,7 +2,6 @@ import { events } from '../../events'
|
||||
import { BaseExtension } from '../../extension'
|
||||
import { MessageRequest, Model, ModelEvent } from '../../../types'
|
||||
import { EngineManager } from './EngineManager'
|
||||
import { ModelManager } from '../../models/manager'
|
||||
|
||||
/**
|
||||
* Base AIEngine
|
||||
|
||||
@ -6,7 +6,6 @@ import {
|
||||
mkdirSync,
|
||||
appendFileSync,
|
||||
createWriteStream,
|
||||
rmdirSync,
|
||||
} from 'fs'
|
||||
import { JanApiRouteConfiguration, RouteConfiguration } from './configuration'
|
||||
import { join } from 'path'
|
||||
@ -126,7 +125,7 @@ export const createThread = async (thread: any) => {
|
||||
}
|
||||
}
|
||||
|
||||
const threadId = generateThreadId(thread.assistants[0].assistant_id)
|
||||
const threadId = generateThreadId(thread.assistants[0]?.assistant_id)
|
||||
try {
|
||||
const updatedThread = {
|
||||
...thread,
|
||||
@ -280,7 +279,7 @@ export const models = async (request: any, reply: any) => {
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
const response = await fetch(`${CORTEX_API_URL}/models${request.url.split('/models')[1] ?? ""}`, {
|
||||
const response = await fetch(`${CORTEX_API_URL}/models${request.url.split('/models')[1] ?? ''}`, {
|
||||
method: request.method,
|
||||
headers: headers,
|
||||
body: JSON.stringify(request.body),
|
||||
|
||||
@ -11,20 +11,20 @@ export interface MessageInterface {
|
||||
* @param {ThreadMessage} message - The message to be added.
|
||||
* @returns {Promise<void>} A promise that resolves when the message has been added.
|
||||
*/
|
||||
addNewMessage(message: ThreadMessage): Promise<void>
|
||||
|
||||
/**
|
||||
* Writes an array of messages to a specific thread.
|
||||
* @param {string} threadId - The ID of the thread to write the messages to.
|
||||
* @param {ThreadMessage[]} messages - The array of messages to be written.
|
||||
* @returns {Promise<void>} A promise that resolves when the messages have been written.
|
||||
*/
|
||||
writeMessages(threadId: string, messages: ThreadMessage[]): Promise<void>
|
||||
createMessage(message: ThreadMessage): Promise<ThreadMessage>
|
||||
|
||||
/**
|
||||
* Retrieves all messages from a specific thread.
|
||||
* @param {string} threadId - The ID of the thread to retrieve the messages from.
|
||||
* @returns {Promise<ThreadMessage[]>} A promise that resolves to an array of messages from the thread.
|
||||
*/
|
||||
getAllMessages(threadId: string): Promise<ThreadMessage[]>
|
||||
listMessages(threadId: string): Promise<ThreadMessage[]>
|
||||
|
||||
/**
|
||||
* Deletes a specific message from a thread.
|
||||
* @param {string} threadId - The ID of the thread from which the message will be deleted.
|
||||
* @param {string} messageId - The ID of the message to be deleted.
|
||||
* @returns {Promise<void>} A promise that resolves when the message has been successfully deleted.
|
||||
*/
|
||||
deleteMessage(threadId: string, messageId: string): Promise<void>
|
||||
}
|
||||
|
||||
@ -11,15 +11,23 @@ export interface ThreadInterface {
|
||||
* @abstract
|
||||
* @returns {Promise<Thread[]>} A promise that resolves to an array of threads.
|
||||
*/
|
||||
getThreads(): Promise<Thread[]>
|
||||
listThreads(): Promise<Thread[]>
|
||||
|
||||
/**
|
||||
* Saves a thread.
|
||||
* Create a thread.
|
||||
* @abstract
|
||||
* @param {Thread} thread - The thread to save.
|
||||
* @returns {Promise<void>} A promise that resolves when the thread is saved.
|
||||
*/
|
||||
saveThread(thread: Thread): Promise<void>
|
||||
createThread(thread: Thread): Promise<Thread>
|
||||
|
||||
/**
|
||||
* modify a thread.
|
||||
* @abstract
|
||||
* @param {Thread} thread - The thread to save.
|
||||
* @returns {Promise<void>} A promise that resolves when the thread is saved.
|
||||
*/
|
||||
modifyThread(thread: Thread): Promise<void>
|
||||
|
||||
/**
|
||||
* Deletes a thread.
|
||||
|
||||
@ -18,12 +18,14 @@
|
||||
"devDependencies": {
|
||||
"cpx": "^1.5.0",
|
||||
"rimraf": "^3.0.2",
|
||||
"ts-loader": "^9.5.0",
|
||||
"webpack": "^5.88.2",
|
||||
"webpack-cli": "^5.1.4",
|
||||
"ts-loader": "^9.5.0"
|
||||
"webpack-cli": "^5.1.4"
|
||||
},
|
||||
"dependencies": {
|
||||
"@janhq/core": "file:../../core"
|
||||
"@janhq/core": "file:../../core",
|
||||
"ky": "^1.7.2",
|
||||
"p-queue": "^8.0.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
|
||||
14
extensions/conversational-extension/src/@types/global.d.ts
vendored
Normal file
14
extensions/conversational-extension/src/@types/global.d.ts
vendored
Normal file
@ -0,0 +1,14 @@
|
||||
export {}
|
||||
declare global {
|
||||
declare const API_URL: string
|
||||
declare const SOCKET_URL: string
|
||||
|
||||
interface Core {
|
||||
api: APIFunctions
|
||||
events: EventEmitter
|
||||
}
|
||||
interface Window {
|
||||
core?: Core | undefined
|
||||
electronAPI?: any | undefined
|
||||
}
|
||||
}
|
||||
@ -1,408 +0,0 @@
|
||||
/**
|
||||
* @jest-environment jsdom
|
||||
*/
|
||||
jest.mock('@janhq/core', () => ({
|
||||
...jest.requireActual('@janhq/core/node'),
|
||||
fs: {
|
||||
existsSync: jest.fn(),
|
||||
mkdir: jest.fn(),
|
||||
writeFileSync: jest.fn(),
|
||||
readdirSync: jest.fn(),
|
||||
readFileSync: jest.fn(),
|
||||
appendFileSync: jest.fn(),
|
||||
rm: jest.fn(),
|
||||
writeBlob: jest.fn(),
|
||||
joinPath: jest.fn(),
|
||||
fileStat: jest.fn(),
|
||||
},
|
||||
joinPath: jest.fn(),
|
||||
ConversationalExtension: jest.fn(),
|
||||
}))
|
||||
|
||||
import { fs } from '@janhq/core'
|
||||
|
||||
import JSONConversationalExtension from '.'
|
||||
|
||||
describe('JSONConversationalExtension Tests', () => {
|
||||
let extension: JSONConversationalExtension
|
||||
|
||||
beforeEach(() => {
|
||||
// @ts-ignore
|
||||
extension = new JSONConversationalExtension()
|
||||
})
|
||||
|
||||
it('should create thread folder on load if it does not exist', async () => {
|
||||
// @ts-ignore
|
||||
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
|
||||
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
|
||||
|
||||
await extension.onLoad()
|
||||
|
||||
expect(mkdirSpy).toHaveBeenCalledWith('file://threads')
|
||||
})
|
||||
|
||||
it('should log message on unload', () => {
|
||||
const consoleSpy = jest.spyOn(console, 'debug').mockImplementation()
|
||||
|
||||
extension.onUnload()
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
'JSONConversationalExtension unloaded'
|
||||
)
|
||||
})
|
||||
|
||||
it('should return sorted threads', async () => {
|
||||
jest
|
||||
.spyOn(extension, 'getValidThreadDirs')
|
||||
.mockResolvedValue(['dir1', 'dir2'])
|
||||
jest
|
||||
.spyOn(extension, 'readThread')
|
||||
.mockResolvedValueOnce({ updated: '2023-01-01' })
|
||||
.mockResolvedValueOnce({ updated: '2023-01-02' })
|
||||
|
||||
const threads = await extension.getThreads()
|
||||
|
||||
expect(threads).toEqual([
|
||||
{ updated: '2023-01-02' },
|
||||
{ updated: '2023-01-01' },
|
||||
])
|
||||
})
|
||||
|
||||
it('should ignore broken threads', async () => {
|
||||
jest
|
||||
.spyOn(extension, 'getValidThreadDirs')
|
||||
.mockResolvedValue(['dir1', 'dir2'])
|
||||
jest
|
||||
.spyOn(extension, 'readThread')
|
||||
.mockResolvedValueOnce(JSON.stringify({ updated: '2023-01-01' }))
|
||||
.mockResolvedValueOnce('this_is_an_invalid_json_content')
|
||||
|
||||
const threads = await extension.getThreads()
|
||||
|
||||
expect(threads).toEqual([{ updated: '2023-01-01' }])
|
||||
})
|
||||
|
||||
it('should save thread', async () => {
|
||||
// @ts-ignore
|
||||
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
|
||||
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
|
||||
const writeFileSyncSpy = jest
|
||||
.spyOn(fs, 'writeFileSync')
|
||||
.mockResolvedValue({})
|
||||
|
||||
const thread = { id: '1', updated: '2023-01-01' } as any
|
||||
await extension.saveThread(thread)
|
||||
|
||||
expect(mkdirSpy).toHaveBeenCalled()
|
||||
expect(writeFileSyncSpy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should delete thread', async () => {
|
||||
const rmSpy = jest.spyOn(fs, 'rm').mockResolvedValue({})
|
||||
|
||||
await extension.deleteThread('1')
|
||||
|
||||
expect(rmSpy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should add new message', async () => {
|
||||
// @ts-ignore
|
||||
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
|
||||
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
|
||||
const appendFileSyncSpy = jest
|
||||
.spyOn(fs, 'appendFileSync')
|
||||
.mockResolvedValue({})
|
||||
|
||||
const message = {
|
||||
thread_id: '1',
|
||||
content: [{ type: 'text', text: { annotations: [] } }],
|
||||
} as any
|
||||
await extension.addNewMessage(message)
|
||||
|
||||
expect(mkdirSpy).toHaveBeenCalled()
|
||||
expect(appendFileSyncSpy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should store image', async () => {
|
||||
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
|
||||
|
||||
await extension.storeImage(
|
||||
'data:image/png;base64,abcd',
|
||||
'path/to/image.png'
|
||||
)
|
||||
|
||||
expect(writeBlobSpy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should store file', async () => {
|
||||
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
|
||||
|
||||
await extension.storeFile(
|
||||
'data:application/pdf;base64,abcd',
|
||||
'path/to/file.pdf'
|
||||
)
|
||||
|
||||
expect(writeBlobSpy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should write messages', async () => {
|
||||
// @ts-ignore
|
||||
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
|
||||
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
|
||||
const writeFileSyncSpy = jest
|
||||
.spyOn(fs, 'writeFileSync')
|
||||
.mockResolvedValue({})
|
||||
|
||||
const messages = [{ id: '1', thread_id: '1', content: [] }] as any
|
||||
await extension.writeMessages('1', messages)
|
||||
|
||||
expect(mkdirSpy).toHaveBeenCalled()
|
||||
expect(writeFileSyncSpy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should get all messages on string response', async () => {
|
||||
jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl'])
|
||||
jest.spyOn(fs, 'readFileSync').mockResolvedValue('{"id":"1"}\n{"id":"2"}\n')
|
||||
|
||||
const messages = await extension.getAllMessages('1')
|
||||
|
||||
expect(messages).toEqual([{ id: '1' }, { id: '2' }])
|
||||
})
|
||||
|
||||
it('should get all messages on object response', async () => {
|
||||
jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl'])
|
||||
jest.spyOn(fs, 'readFileSync').mockResolvedValue({ id: 1 })
|
||||
|
||||
const messages = await extension.getAllMessages('1')
|
||||
|
||||
expect(messages).toEqual([{ id: 1 }])
|
||||
})
|
||||
|
||||
it('get all messages return empty on error', async () => {
|
||||
jest.spyOn(fs, 'readdirSync').mockRejectedValue(['messages.jsonl'])
|
||||
|
||||
const messages = await extension.getAllMessages('1')
|
||||
|
||||
expect(messages).toEqual([])
|
||||
})
|
||||
|
||||
it('return empty messages on no messages file', async () => {
|
||||
jest.spyOn(fs, 'readdirSync').mockResolvedValue([])
|
||||
|
||||
const messages = await extension.getAllMessages('1')
|
||||
|
||||
expect(messages).toEqual([])
|
||||
})
|
||||
|
||||
it('should ignore error message', async () => {
|
||||
jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl'])
|
||||
jest
|
||||
.spyOn(fs, 'readFileSync')
|
||||
.mockResolvedValue('{"id":"1"}\nyolo\n{"id":"2"}\n')
|
||||
|
||||
const messages = await extension.getAllMessages('1')
|
||||
|
||||
expect(messages).toEqual([{ id: '1' }, { id: '2' }])
|
||||
})
|
||||
|
||||
it('should create thread folder on load if it does not exist', async () => {
|
||||
// @ts-ignore
|
||||
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
|
||||
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
|
||||
|
||||
await extension.onLoad()
|
||||
|
||||
expect(mkdirSpy).toHaveBeenCalledWith('file://threads')
|
||||
})
|
||||
|
||||
it('should log message on unload', () => {
|
||||
const consoleSpy = jest.spyOn(console, 'debug').mockImplementation()
|
||||
|
||||
extension.onUnload()
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
'JSONConversationalExtension unloaded'
|
||||
)
|
||||
})
|
||||
|
||||
it('should return sorted threads', async () => {
|
||||
jest
|
||||
.spyOn(extension, 'getValidThreadDirs')
|
||||
.mockResolvedValue(['dir1', 'dir2'])
|
||||
jest
|
||||
.spyOn(extension, 'readThread')
|
||||
.mockResolvedValueOnce({ updated: '2023-01-01' })
|
||||
.mockResolvedValueOnce({ updated: '2023-01-02' })
|
||||
|
||||
const threads = await extension.getThreads()
|
||||
|
||||
expect(threads).toEqual([
|
||||
{ updated: '2023-01-02' },
|
||||
{ updated: '2023-01-01' },
|
||||
])
|
||||
})
|
||||
|
||||
it('should ignore broken threads', async () => {
|
||||
jest
|
||||
.spyOn(extension, 'getValidThreadDirs')
|
||||
.mockResolvedValue(['dir1', 'dir2'])
|
||||
jest
|
||||
.spyOn(extension, 'readThread')
|
||||
.mockResolvedValueOnce(JSON.stringify({ updated: '2023-01-01' }))
|
||||
.mockResolvedValueOnce('this_is_an_invalid_json_content')
|
||||
|
||||
const threads = await extension.getThreads()
|
||||
|
||||
expect(threads).toEqual([{ updated: '2023-01-01' }])
|
||||
})
|
||||
|
||||
it('should save thread', async () => {
|
||||
// @ts-ignore
|
||||
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
|
||||
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
|
||||
const writeFileSyncSpy = jest
|
||||
.spyOn(fs, 'writeFileSync')
|
||||
.mockResolvedValue({})
|
||||
|
||||
const thread = { id: '1', updated: '2023-01-01' } as any
|
||||
await extension.saveThread(thread)
|
||||
|
||||
expect(mkdirSpy).toHaveBeenCalled()
|
||||
expect(writeFileSyncSpy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should delete thread', async () => {
|
||||
const rmSpy = jest.spyOn(fs, 'rm').mockResolvedValue({})
|
||||
|
||||
await extension.deleteThread('1')
|
||||
|
||||
expect(rmSpy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should add new message', async () => {
|
||||
// @ts-ignore
|
||||
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
|
||||
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
|
||||
const appendFileSyncSpy = jest
|
||||
.spyOn(fs, 'appendFileSync')
|
||||
.mockResolvedValue({})
|
||||
|
||||
const message = {
|
||||
thread_id: '1',
|
||||
content: [{ type: 'text', text: { annotations: [] } }],
|
||||
} as any
|
||||
await extension.addNewMessage(message)
|
||||
|
||||
expect(mkdirSpy).toHaveBeenCalled()
|
||||
expect(appendFileSyncSpy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should add new image message', async () => {
|
||||
jest
|
||||
.spyOn(fs, 'existsSync')
|
||||
// @ts-ignore
|
||||
.mockResolvedValueOnce(false)
|
||||
// @ts-ignore
|
||||
.mockResolvedValueOnce(false)
|
||||
// @ts-ignore
|
||||
.mockResolvedValueOnce(true)
|
||||
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
|
||||
const appendFileSyncSpy = jest
|
||||
.spyOn(fs, 'appendFileSync')
|
||||
.mockResolvedValue({})
|
||||
jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
|
||||
|
||||
const message = {
|
||||
thread_id: '1',
|
||||
content: [
|
||||
{ type: 'image', text: { annotations: ['data:image;base64,hehe'] } },
|
||||
],
|
||||
} as any
|
||||
await extension.addNewMessage(message)
|
||||
|
||||
expect(mkdirSpy).toHaveBeenCalled()
|
||||
expect(appendFileSyncSpy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should add new pdf message', async () => {
|
||||
jest
|
||||
.spyOn(fs, 'existsSync')
|
||||
// @ts-ignore
|
||||
.mockResolvedValueOnce(false)
|
||||
// @ts-ignore
|
||||
.mockResolvedValueOnce(false)
|
||||
// @ts-ignore
|
||||
.mockResolvedValueOnce(true)
|
||||
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
|
||||
const appendFileSyncSpy = jest
|
||||
.spyOn(fs, 'appendFileSync')
|
||||
.mockResolvedValue({})
|
||||
jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
|
||||
|
||||
const message = {
|
||||
thread_id: '1',
|
||||
content: [
|
||||
{ type: 'pdf', text: { annotations: ['data:pdf;base64,hehe'] } },
|
||||
],
|
||||
} as any
|
||||
await extension.addNewMessage(message)
|
||||
|
||||
expect(mkdirSpy).toHaveBeenCalled()
|
||||
expect(appendFileSyncSpy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should store image', async () => {
|
||||
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
|
||||
|
||||
await extension.storeImage(
|
||||
'data:image/png;base64,abcd',
|
||||
'path/to/image.png'
|
||||
)
|
||||
|
||||
expect(writeBlobSpy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should store file', async () => {
|
||||
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
|
||||
|
||||
await extension.storeFile(
|
||||
'data:application/pdf;base64,abcd',
|
||||
'path/to/file.pdf'
|
||||
)
|
||||
|
||||
expect(writeBlobSpy).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('test readThread', () => {
|
||||
let extension: JSONConversationalExtension
|
||||
|
||||
beforeEach(() => {
|
||||
// @ts-ignore
|
||||
extension = new JSONConversationalExtension()
|
||||
})
|
||||
|
||||
it('should read thread', async () => {
|
||||
jest
|
||||
.spyOn(fs, 'readFileSync')
|
||||
.mockResolvedValue(JSON.stringify({ id: '1' }))
|
||||
const thread = await extension.readThread('1')
|
||||
expect(thread).toEqual(`{"id":"1"}`)
|
||||
})
|
||||
|
||||
it('getValidThreadDirs should return valid thread directories', async () => {
|
||||
jest
|
||||
.spyOn(fs, 'readdirSync')
|
||||
.mockResolvedValueOnce(['1', '2', '3'])
|
||||
.mockResolvedValueOnce(['thread.json'])
|
||||
.mockResolvedValueOnce(['thread.json'])
|
||||
.mockResolvedValueOnce([])
|
||||
// @ts-ignore
|
||||
jest.spyOn(fs, 'existsSync').mockResolvedValue(true)
|
||||
jest.spyOn(fs, 'fileStat').mockResolvedValue({
|
||||
isDirectory: true,
|
||||
} as any)
|
||||
const validThreadDirs = await extension.getValidThreadDirs()
|
||||
expect(validThreadDirs).toEqual(['1', '2'])
|
||||
})
|
||||
})
|
||||
@ -1,90 +1,71 @@
|
||||
import {
|
||||
fs,
|
||||
joinPath,
|
||||
ConversationalExtension,
|
||||
Thread,
|
||||
ThreadAssistantInfo,
|
||||
ThreadMessage,
|
||||
} from '@janhq/core'
|
||||
import { safelyParseJSON } from './jsonUtil'
|
||||
import ky from 'ky'
|
||||
import PQueue from 'p-queue'
|
||||
|
||||
type ThreadList = {
|
||||
data: Thread[]
|
||||
}
|
||||
|
||||
type MessageList = {
|
||||
data: ThreadMessage[]
|
||||
}
|
||||
|
||||
/**
|
||||
* JSONConversationalExtension is a ConversationalExtension implementation that provides
|
||||
* functionality for managing threads.
|
||||
*/
|
||||
export default class JSONConversationalExtension extends ConversationalExtension {
|
||||
private static readonly _threadFolder = 'file://threads'
|
||||
private static readonly _threadInfoFileName = 'thread.json'
|
||||
private static readonly _threadMessagesFileName = 'messages.jsonl'
|
||||
queue = new PQueue({ concurrency: 1 })
|
||||
|
||||
/**
|
||||
* Called when the extension is loaded.
|
||||
*/
|
||||
async onLoad() {
|
||||
if (!(await fs.existsSync(JSONConversationalExtension._threadFolder))) {
|
||||
await fs.mkdir(JSONConversationalExtension._threadFolder)
|
||||
}
|
||||
this.queue.add(() => this.healthz())
|
||||
}
|
||||
|
||||
/**
|
||||
* Called when the extension is unloaded.
|
||||
*/
|
||||
onUnload() {
|
||||
console.debug('JSONConversationalExtension unloaded')
|
||||
}
|
||||
onUnload() {}
|
||||
|
||||
/**
|
||||
* Returns a Promise that resolves to an array of Conversation objects.
|
||||
*/
|
||||
async getThreads(): Promise<Thread[]> {
|
||||
try {
|
||||
const threadDirs = await this.getValidThreadDirs()
|
||||
|
||||
const promises = threadDirs.map((dirName) => this.readThread(dirName))
|
||||
const promiseResults = await Promise.allSettled(promises)
|
||||
const convos = promiseResults
|
||||
.map((result) => {
|
||||
if (result.status === 'fulfilled') {
|
||||
return typeof result.value === 'object'
|
||||
? result.value
|
||||
: safelyParseJSON(result.value)
|
||||
}
|
||||
return undefined
|
||||
})
|
||||
.filter((convo) => !!convo)
|
||||
convos.sort(
|
||||
(a, b) => new Date(b.updated).getTime() - new Date(a.updated).getTime()
|
||||
)
|
||||
|
||||
return convos
|
||||
} catch (error) {
|
||||
console.error(error)
|
||||
return []
|
||||
}
|
||||
async listThreads(): Promise<Thread[]> {
|
||||
return this.queue.add(() =>
|
||||
ky
|
||||
.get(`${API_URL}/v1/threads`)
|
||||
.json<ThreadList>()
|
||||
.then((e) => e.data)
|
||||
) as Promise<Thread[]>
|
||||
}
|
||||
|
||||
/**
|
||||
* Saves a Thread object to a json file.
|
||||
* @param thread The Thread object to save.
|
||||
*/
|
||||
async saveThread(thread: Thread): Promise<void> {
|
||||
try {
|
||||
const threadDirPath = await joinPath([
|
||||
JSONConversationalExtension._threadFolder,
|
||||
thread.id,
|
||||
])
|
||||
const threadJsonPath = await joinPath([
|
||||
threadDirPath,
|
||||
JSONConversationalExtension._threadInfoFileName,
|
||||
])
|
||||
if (!(await fs.existsSync(threadDirPath))) {
|
||||
await fs.mkdir(threadDirPath)
|
||||
}
|
||||
async createThread(thread: Thread): Promise<Thread> {
|
||||
return this.queue.add(() =>
|
||||
ky.post(`${API_URL}/v1/threads`, { json: thread }).json<Thread>()
|
||||
) as Promise<Thread>
|
||||
}
|
||||
|
||||
await fs.writeFileSync(threadJsonPath, JSON.stringify(thread, null, 2))
|
||||
} catch (err) {
|
||||
console.error(err)
|
||||
Promise.reject(err)
|
||||
}
|
||||
/**
|
||||
* Saves a Thread object to a json file.
|
||||
* @param thread The Thread object to save.
|
||||
*/
|
||||
async modifyThread(thread: Thread): Promise<void> {
|
||||
return this.queue
|
||||
.add(() =>
|
||||
ky.post(`${API_URL}/v1/threads/${thread.id}`, { json: thread })
|
||||
)
|
||||
.then()
|
||||
}
|
||||
|
||||
/**
|
||||
@ -92,189 +73,126 @@ export default class JSONConversationalExtension extends ConversationalExtension
|
||||
* @param threadId The ID of the thread to delete.
|
||||
*/
|
||||
async deleteThread(threadId: string): Promise<void> {
|
||||
const path = await joinPath([
|
||||
JSONConversationalExtension._threadFolder,
|
||||
`${threadId}`,
|
||||
])
|
||||
try {
|
||||
await fs.rm(path)
|
||||
} catch (err) {
|
||||
console.error(err)
|
||||
}
|
||||
return this.queue
|
||||
.add(() => ky.delete(`${API_URL}/v1/threads/${threadId}`))
|
||||
.then()
|
||||
}
|
||||
|
||||
async addNewMessage(message: ThreadMessage): Promise<void> {
|
||||
try {
|
||||
const threadDirPath = await joinPath([
|
||||
JSONConversationalExtension._threadFolder,
|
||||
message.thread_id,
|
||||
])
|
||||
const threadMessagePath = await joinPath([
|
||||
threadDirPath,
|
||||
JSONConversationalExtension._threadMessagesFileName,
|
||||
])
|
||||
if (!(await fs.existsSync(threadDirPath))) await fs.mkdir(threadDirPath)
|
||||
|
||||
if (message.content[0]?.type === 'image') {
|
||||
const filesPath = await joinPath([threadDirPath, 'files'])
|
||||
if (!(await fs.existsSync(filesPath))) await fs.mkdir(filesPath)
|
||||
|
||||
const imagePath = await joinPath([filesPath, `${message.id}.png`])
|
||||
const base64 = message.content[0].text.annotations[0]
|
||||
await this.storeImage(base64, imagePath)
|
||||
if ((await fs.existsSync(imagePath)) && message.content?.length) {
|
||||
// Use file path instead of blob
|
||||
message.content[0].text.annotations[0] = `threads/${message.thread_id}/files/${message.id}.png`
|
||||
}
|
||||
}
|
||||
|
||||
if (message.content[0]?.type === 'pdf') {
|
||||
const filesPath = await joinPath([threadDirPath, 'files'])
|
||||
if (!(await fs.existsSync(filesPath))) await fs.mkdir(filesPath)
|
||||
|
||||
const filePath = await joinPath([filesPath, `${message.id}.pdf`])
|
||||
const blob = message.content[0].text.annotations[0]
|
||||
await this.storeFile(blob, filePath)
|
||||
|
||||
if ((await fs.existsSync(filePath)) && message.content?.length) {
|
||||
// Use file path instead of blob
|
||||
message.content[0].text.annotations[0] = `threads/${message.thread_id}/files/${message.id}.pdf`
|
||||
}
|
||||
}
|
||||
await fs.appendFileSync(threadMessagePath, JSON.stringify(message) + '\n')
|
||||
Promise.resolve()
|
||||
} catch (err) {
|
||||
Promise.reject(err)
|
||||
}
|
||||
/**
|
||||
* Adds a new message to a specified thread.
|
||||
* @param message The ThreadMessage object to be added.
|
||||
* @returns A Promise that resolves when the message has been added.
|
||||
*/
|
||||
async createMessage(message: ThreadMessage): Promise<ThreadMessage> {
|
||||
return this.queue.add(() =>
|
||||
ky
|
||||
.post(`${API_URL}/v1/threads/${message.thread_id}/messages`, {
|
||||
json: message,
|
||||
})
|
||||
.json<ThreadMessage>()
|
||||
) as Promise<ThreadMessage>
|
||||
}
|
||||
|
||||
async storeImage(base64: string, filePath: string): Promise<void> {
|
||||
const base64Data = base64.replace(/^data:image\/\w+;base64,/, '')
|
||||
|
||||
try {
|
||||
await fs.writeBlob(filePath, base64Data)
|
||||
} catch (err) {
|
||||
console.error(err)
|
||||
}
|
||||
/**
|
||||
* Modifies a message in a thread.
|
||||
* @param message
|
||||
* @returns
|
||||
*/
|
||||
async modifyMessage(message: ThreadMessage): Promise<ThreadMessage> {
|
||||
return this.queue.add(() =>
|
||||
ky
|
||||
.post(
|
||||
`${API_URL}/v1/threads/${message.thread_id}/messages/${message.id}`,
|
||||
{
|
||||
json: message,
|
||||
}
|
||||
)
|
||||
.json<ThreadMessage>()
|
||||
) as Promise<ThreadMessage>
|
||||
}
|
||||
|
||||
async storeFile(base64: string, filePath: string): Promise<void> {
|
||||
const base64Data = base64.replace(/^data:application\/pdf;base64,/, '')
|
||||
try {
|
||||
await fs.writeBlob(filePath, base64Data)
|
||||
} catch (err) {
|
||||
console.error(err)
|
||||
}
|
||||
/**
|
||||
* Deletes a specific message from a thread.
|
||||
* @param threadId The ID of the thread containing the message.
|
||||
* @param messageId The ID of the message to be deleted.
|
||||
* @returns A Promise that resolves when the message has been successfully deleted.
|
||||
*/
|
||||
async deleteMessage(threadId: string, messageId: string): Promise<void> {
|
||||
return this.queue
|
||||
.add(() =>
|
||||
ky.delete(`${API_URL}/v1/threads/${threadId}/messages/${messageId}`)
|
||||
)
|
||||
.then()
|
||||
}
|
||||
|
||||
async writeMessages(
|
||||
/**
|
||||
* Retrieves all messages for a specified thread.
|
||||
* @param threadId The ID of the thread to get messages from.
|
||||
* @returns A Promise that resolves to an array of ThreadMessage objects.
|
||||
*/
|
||||
async listMessages(threadId: string): Promise<ThreadMessage[]> {
|
||||
return this.queue.add(() =>
|
||||
ky
|
||||
.get(`${API_URL}/v1/threads/${threadId}/messages?order=asc`)
|
||||
.json<MessageList>()
|
||||
.then((e) => e.data)
|
||||
) as Promise<ThreadMessage[]>
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the assistant information for a specified thread.
|
||||
* @param threadId The ID of the thread for which to retrieve assistant information.
|
||||
* @returns A Promise that resolves to a ThreadAssistantInfo object containing
|
||||
* the details of the assistant associated with the specified thread.
|
||||
*/
|
||||
async getThreadAssistant(threadId: string): Promise<ThreadAssistantInfo> {
|
||||
return this.queue.add(() =>
|
||||
ky.get(`${API_URL}/v1/assistants/${threadId}`).json<ThreadAssistantInfo>()
|
||||
) as Promise<ThreadAssistantInfo>
|
||||
}
|
||||
/**
|
||||
* Creates a new assistant for the specified thread.
|
||||
* @param threadId The ID of the thread for which the assistant is being created.
|
||||
* @param assistant The information about the assistant to be created.
|
||||
* @returns A Promise that resolves to the newly created ThreadAssistantInfo object.
|
||||
*/
|
||||
async createThreadAssistant(
|
||||
threadId: string,
|
||||
messages: ThreadMessage[]
|
||||
): Promise<void> {
|
||||
try {
|
||||
const threadDirPath = await joinPath([
|
||||
JSONConversationalExtension._threadFolder,
|
||||
threadId,
|
||||
])
|
||||
const threadMessagePath = await joinPath([
|
||||
threadDirPath,
|
||||
JSONConversationalExtension._threadMessagesFileName,
|
||||
])
|
||||
if (!(await fs.existsSync(threadDirPath))) await fs.mkdir(threadDirPath)
|
||||
await fs.writeFileSync(
|
||||
threadMessagePath,
|
||||
messages.map((msg) => JSON.stringify(msg)).join('\n') +
|
||||
(messages.length ? '\n' : '')
|
||||
)
|
||||
Promise.resolve()
|
||||
} catch (err) {
|
||||
Promise.reject(err)
|
||||
}
|
||||
assistant: ThreadAssistantInfo
|
||||
): Promise<ThreadAssistantInfo> {
|
||||
return this.queue.add(() =>
|
||||
ky
|
||||
.post(`${API_URL}/v1/assistants/${threadId}`, { json: assistant })
|
||||
.json<ThreadAssistantInfo>()
|
||||
) as Promise<ThreadAssistantInfo>
|
||||
}
|
||||
|
||||
/**
|
||||
* A promise builder for reading a thread from a file.
|
||||
* @param threadDirName the thread dir we are reading from.
|
||||
* @returns data of the thread
|
||||
* Modifies an existing assistant for the specified thread.
|
||||
* @param threadId The ID of the thread for which the assistant is being modified.
|
||||
* @param assistant The updated information for the assistant.
|
||||
* @returns A Promise that resolves to the updated ThreadAssistantInfo object.
|
||||
*/
|
||||
async readThread(threadDirName: string): Promise<any> {
|
||||
return fs.readFileSync(
|
||||
await joinPath([
|
||||
JSONConversationalExtension._threadFolder,
|
||||
threadDirName,
|
||||
JSONConversationalExtension._threadInfoFileName,
|
||||
]),
|
||||
'utf-8'
|
||||
)
|
||||
async modifyThreadAssistant(
|
||||
threadId: string,
|
||||
assistant: ThreadAssistantInfo
|
||||
): Promise<ThreadAssistantInfo> {
|
||||
return this.queue.add(() =>
|
||||
ky
|
||||
.patch(`${API_URL}/v1/assistants/${threadId}`, { json: assistant })
|
||||
.json<ThreadAssistantInfo>()
|
||||
) as Promise<ThreadAssistantInfo>
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a Promise that resolves to an array of thread directories.
|
||||
* @private
|
||||
* Do health check on cortex.cpp
|
||||
* @returns
|
||||
*/
|
||||
async getValidThreadDirs(): Promise<string[]> {
|
||||
const fileInsideThread: string[] = await fs.readdirSync(
|
||||
JSONConversationalExtension._threadFolder
|
||||
)
|
||||
|
||||
const threadDirs: string[] = []
|
||||
for (let i = 0; i < fileInsideThread.length; i++) {
|
||||
const path = await joinPath([
|
||||
JSONConversationalExtension._threadFolder,
|
||||
fileInsideThread[i],
|
||||
])
|
||||
if (!(await fs.fileStat(path))?.isDirectory) continue
|
||||
|
||||
const isHavingThreadInfo = (await fs.readdirSync(path)).includes(
|
||||
JSONConversationalExtension._threadInfoFileName
|
||||
)
|
||||
if (!isHavingThreadInfo) {
|
||||
console.debug(`Ignore ${path} because it does not have thread info`)
|
||||
continue
|
||||
}
|
||||
|
||||
threadDirs.push(fileInsideThread[i])
|
||||
}
|
||||
return threadDirs
|
||||
}
|
||||
|
||||
async getAllMessages(threadId: string): Promise<ThreadMessage[]> {
|
||||
try {
|
||||
const threadDirPath = await joinPath([
|
||||
JSONConversationalExtension._threadFolder,
|
||||
threadId,
|
||||
])
|
||||
|
||||
const files: string[] = await fs.readdirSync(threadDirPath)
|
||||
if (
|
||||
!files.includes(JSONConversationalExtension._threadMessagesFileName)
|
||||
) {
|
||||
console.debug(`${threadDirPath} not contains message file`)
|
||||
return []
|
||||
}
|
||||
|
||||
const messageFilePath = await joinPath([
|
||||
threadDirPath,
|
||||
JSONConversationalExtension._threadMessagesFileName,
|
||||
])
|
||||
|
||||
let readResult = await fs.readFileSync(messageFilePath, 'utf-8')
|
||||
|
||||
if (typeof readResult === 'object') {
|
||||
readResult = JSON.stringify(readResult)
|
||||
}
|
||||
|
||||
const result = readResult.split('\n').filter((line) => line !== '')
|
||||
|
||||
const messages: ThreadMessage[] = []
|
||||
result.forEach((line: string) => {
|
||||
const message = safelyParseJSON(line)
|
||||
if (message) messages.push(safelyParseJSON(line))
|
||||
healthz(): Promise<void> {
|
||||
return ky
|
||||
.get(`${API_URL}/healthz`, {
|
||||
retry: { limit: 20, delay: () => 500, methods: ['get'] },
|
||||
})
|
||||
return messages
|
||||
} catch (err) {
|
||||
console.error(err)
|
||||
return []
|
||||
}
|
||||
.then(() => {})
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,14 +0,0 @@
|
||||
// Note about performance
|
||||
// The v8 JavaScript engine used by Node.js cannot optimise functions which contain a try/catch block.
|
||||
// v8 4.5 and above can optimise try/catch
|
||||
export function safelyParseJSON(json) {
|
||||
// This function cannot be optimised, it's best to
|
||||
// keep it small!
|
||||
var parsed
|
||||
try {
|
||||
parsed = JSON.parse(json)
|
||||
} catch (e) {
|
||||
return undefined
|
||||
}
|
||||
return parsed // Could be undefined!
|
||||
}
|
||||
@ -17,7 +17,12 @@ module.exports = {
|
||||
filename: 'index.js', // Adjust the output file name as needed
|
||||
library: { type: 'module' }, // Specify ESM output format
|
||||
},
|
||||
plugins: [new webpack.DefinePlugin({})],
|
||||
plugins: [
|
||||
new webpack.DefinePlugin({
|
||||
API_URL: JSON.stringify('http://127.0.0.1:39291'),
|
||||
SOCKET_URL: JSON.stringify('ws://127.0.0.1:39291'),
|
||||
}),
|
||||
],
|
||||
resolve: {
|
||||
extensions: ['.ts', '.js'],
|
||||
},
|
||||
|
||||
@ -18,14 +18,14 @@ import { isLocalEngine } from '@/utils/modelEngine'
|
||||
|
||||
import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
|
||||
|
||||
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom'
|
||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
const ErrorMessage = ({ message }: { message: ThreadMessage }) => {
|
||||
const setModalTroubleShooting = useSetAtom(modalTroubleShootingAtom)
|
||||
const setMainState = useSetAtom(mainViewStateAtom)
|
||||
const setSelectedSettingScreen = useSetAtom(selectedSettingAtom)
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||
|
||||
const defaultDesc = () => {
|
||||
return (
|
||||
|
||||
@ -46,6 +46,7 @@ import {
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
|
||||
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom'
|
||||
import {
|
||||
configuredModelsAtom,
|
||||
@ -75,6 +76,7 @@ const ModelDropdown = ({
|
||||
const [searchText, setSearchText] = useState('')
|
||||
const [open, setOpen] = useState(false)
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||
const downloadingModels = useAtomValue(getDownloadingModelAtom)
|
||||
const [toggle, setToggle] = useState<HTMLDivElement | null>(null)
|
||||
const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
|
||||
@ -151,17 +153,24 @@ const ModelDropdown = ({
|
||||
|
||||
useEffect(() => {
|
||||
if (!activeThread) return
|
||||
const modelId = activeThread?.assistants?.[0]?.model?.id
|
||||
const modelId = activeAssistant?.model?.id
|
||||
|
||||
let model = downloadedModels.find((model) => model.id === modelId)
|
||||
if (!model) {
|
||||
model = recommendedModel
|
||||
}
|
||||
setSelectedModel(model)
|
||||
}, [recommendedModel, activeThread, downloadedModels, setSelectedModel])
|
||||
}, [
|
||||
recommendedModel,
|
||||
activeThread,
|
||||
downloadedModels,
|
||||
setSelectedModel,
|
||||
activeAssistant?.model?.id,
|
||||
])
|
||||
|
||||
const onClickModelItem = useCallback(
|
||||
async (modelId: string) => {
|
||||
if (!activeAssistant) return
|
||||
const model = downloadedModels.find((m) => m.id === modelId)
|
||||
setSelectedModel(model)
|
||||
setOpen(false)
|
||||
@ -172,14 +181,14 @@ const ModelDropdown = ({
|
||||
...activeThread,
|
||||
assistants: [
|
||||
{
|
||||
...activeThread.assistants[0],
|
||||
...activeAssistant,
|
||||
tools: [
|
||||
{
|
||||
type: 'retrieval',
|
||||
enabled: isModelSupportRagAndTools(model as Model),
|
||||
settings: {
|
||||
...(activeThread.assistants[0].tools &&
|
||||
activeThread.assistants[0].tools[0]?.settings),
|
||||
...(activeAssistant.tools &&
|
||||
activeAssistant.tools[0]?.settings),
|
||||
},
|
||||
},
|
||||
],
|
||||
@ -215,13 +224,14 @@ const ModelDropdown = ({
|
||||
}
|
||||
},
|
||||
[
|
||||
activeAssistant,
|
||||
downloadedModels,
|
||||
activeThread,
|
||||
setSelectedModel,
|
||||
activeThread,
|
||||
updateThreadMetadata,
|
||||
isModelSupportRagAndTools,
|
||||
setThreadModelParams,
|
||||
updateModelParameter,
|
||||
updateThreadMetadata,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { Fragment, useCallback, useEffect, useRef } from 'react'
|
||||
import { Fragment, use, useCallback, useEffect, useRef } from 'react'
|
||||
|
||||
import {
|
||||
ChatCompletionMessage,
|
||||
@ -31,6 +31,7 @@ import {
|
||||
addNewMessageAtom,
|
||||
updateMessageAtom,
|
||||
tokenSpeedAtom,
|
||||
deleteMessageAtom,
|
||||
} from '@/helpers/atoms/ChatMessage.atom'
|
||||
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||
import {
|
||||
@ -49,6 +50,7 @@ export default function ModelHandler() {
|
||||
const addNewMessage = useSetAtom(addNewMessageAtom)
|
||||
const updateMessage = useSetAtom(updateMessageAtom)
|
||||
const downloadedModels = useAtomValue(downloadedModelsAtom)
|
||||
const deleteMessage = useSetAtom(deleteMessageAtom)
|
||||
const activeModel = useAtomValue(activeModelAtom)
|
||||
const setActiveModel = useSetAtom(activeModelAtom)
|
||||
const setStateModel = useSetAtom(stateModelAtom)
|
||||
@ -86,7 +88,7 @@ export default function ModelHandler() {
|
||||
}, [activeModelParams])
|
||||
|
||||
const onNewMessageResponse = useCallback(
|
||||
(message: ThreadMessage) => {
|
||||
async (message: ThreadMessage) => {
|
||||
if (message.type === MessageRequestType.Thread) {
|
||||
addNewMessage(message)
|
||||
}
|
||||
@ -154,12 +156,15 @@ export default function ModelHandler() {
|
||||
...thread,
|
||||
|
||||
title: cleanedMessageContent,
|
||||
metadata: thread.metadata,
|
||||
metadata: {
|
||||
...thread.metadata,
|
||||
title: cleanedMessageContent,
|
||||
},
|
||||
}
|
||||
|
||||
extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.saveThread({
|
||||
?.modifyThread({
|
||||
...updatedThread,
|
||||
})
|
||||
.then(() => {
|
||||
@ -233,7 +238,9 @@ export default function ModelHandler() {
|
||||
|
||||
const thread = threadsRef.current?.find((e) => e.id == message.thread_id)
|
||||
if (!thread) return
|
||||
|
||||
const messageContent = message.content[0]?.text?.value
|
||||
|
||||
const metadata = {
|
||||
...thread.metadata,
|
||||
...(messageContent && { lastMessage: messageContent }),
|
||||
@ -246,15 +253,19 @@ export default function ModelHandler() {
|
||||
|
||||
extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.saveThread({
|
||||
?.modifyThread({
|
||||
...thread,
|
||||
metadata,
|
||||
})
|
||||
|
||||
// If this is not the summary of the Thread, don't need to add it to the Thread
|
||||
extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.addNewMessage(message)
|
||||
;(async () => {
|
||||
const updatedMessage = await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.createMessage(message)
|
||||
if (updatedMessage) {
|
||||
deleteMessage(message.id)
|
||||
addNewMessage(updatedMessage)
|
||||
}
|
||||
})()
|
||||
|
||||
// Attempt to generate the title of the Thread when needed
|
||||
generateThreadTitle(message, thread)
|
||||
@ -279,7 +290,9 @@ export default function ModelHandler() {
|
||||
|
||||
const generateThreadTitle = (message: ThreadMessage, thread: Thread) => {
|
||||
// If this is the first ever prompt in the thread
|
||||
if (thread.title?.trim() !== defaultThreadTitle) {
|
||||
if (
|
||||
(thread.title ?? thread.metadata?.title)?.trim() !== defaultThreadTitle
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
@ -292,11 +305,14 @@ export default function ModelHandler() {
|
||||
const updatedThread: Thread = {
|
||||
...thread,
|
||||
title: (thread.metadata?.lastMessage as string) || defaultThreadTitle,
|
||||
metadata: thread.metadata,
|
||||
metadata: {
|
||||
...thread.metadata,
|
||||
title: (thread.metadata?.lastMessage as string) || defaultThreadTitle,
|
||||
},
|
||||
}
|
||||
return extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.saveThread({
|
||||
?.modifyThread({
|
||||
...updatedThread,
|
||||
})
|
||||
.then(() => {
|
||||
|
||||
@ -1,4 +1,12 @@
|
||||
import { Assistant } from '@janhq/core'
|
||||
import { Assistant, ThreadAssistantInfo } from '@janhq/core'
|
||||
import { atom } from 'jotai'
|
||||
import { atomWithStorage } from 'jotai/utils'
|
||||
|
||||
export const assistantsAtom = atom<Assistant[]>([])
|
||||
|
||||
/**
|
||||
* Get the current active assistant
|
||||
*/
|
||||
export const activeAssistantAtom = atomWithStorage<
|
||||
ThreadAssistantInfo | undefined
|
||||
>('activeAssistant', undefined, undefined, { getOnInit: true })
|
||||
|
||||
@ -6,6 +6,8 @@ import {
|
||||
} from '@janhq/core'
|
||||
import { atom } from 'jotai'
|
||||
|
||||
import { atomWithStorage } from 'jotai/utils'
|
||||
|
||||
import {
|
||||
getActiveThreadIdAtom,
|
||||
updateThreadStateLastMessageAtom,
|
||||
@ -13,15 +15,23 @@ import {
|
||||
|
||||
import { TokenSpeed } from '@/types/token'
|
||||
|
||||
const CHAT_MESSAGE_NAME = 'chatMessages'
|
||||
/**
|
||||
* Stores all chat messages for all threads
|
||||
*/
|
||||
export const chatMessages = atom<Record<string, ThreadMessage[]>>({})
|
||||
export const chatMessages = atomWithStorage<Record<string, ThreadMessage[]>>(
|
||||
CHAT_MESSAGE_NAME,
|
||||
{},
|
||||
undefined,
|
||||
{ getOnInit: true }
|
||||
)
|
||||
|
||||
/**
|
||||
* Stores the status of the messages load for each thread
|
||||
*/
|
||||
export const readyThreadsMessagesAtom = atom<Record<string, boolean>>({})
|
||||
export const readyThreadsMessagesAtom = atomWithStorage<
|
||||
Record<string, boolean>
|
||||
>('currentThreadMessages', {}, undefined, { getOnInit: true })
|
||||
|
||||
/**
|
||||
* Store the token speed for current message
|
||||
@ -34,6 +44,7 @@ export const getCurrentChatMessagesAtom = atom<ThreadMessage[]>((get) => {
|
||||
const activeThreadId = get(getActiveThreadIdAtom)
|
||||
if (!activeThreadId) return []
|
||||
const messages = get(chatMessages)[activeThreadId]
|
||||
if (!Array.isArray(messages)) return []
|
||||
return messages ?? []
|
||||
})
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ import { toaster } from '@/containers/Toast'
|
||||
import { LAST_USED_MODEL_ID } from './useRecommendedModel'
|
||||
|
||||
import { vulkanEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
@ -34,6 +35,7 @@ export function useActiveModel() {
|
||||
const setLoadModelError = useSetAtom(loadModelErrorAtom)
|
||||
const pendingModelLoad = useRef(false)
|
||||
const isVulkanEnabled = useAtomValue(vulkanEnabledAtom)
|
||||
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||
|
||||
const downloadedModelsRef = useRef<Model[]>([])
|
||||
|
||||
@ -79,12 +81,12 @@ export function useActiveModel() {
|
||||
}
|
||||
|
||||
/// Apply thread model settings
|
||||
if (activeThread?.assistants[0]?.model.id === modelId) {
|
||||
if (activeAssistant?.model.id === modelId) {
|
||||
model = {
|
||||
...model,
|
||||
settings: {
|
||||
...model.settings,
|
||||
...activeThread.assistants[0].model.settings,
|
||||
...activeAssistant?.model.settings,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import { useCallback } from 'react'
|
||||
|
||||
import {
|
||||
Assistant,
|
||||
ConversationalExtension,
|
||||
ExtensionTypeEnum,
|
||||
Thread,
|
||||
@ -9,16 +8,17 @@ import {
|
||||
ThreadState,
|
||||
AssistantTool,
|
||||
Model,
|
||||
Assistant,
|
||||
} from '@janhq/core'
|
||||
import { atom, useAtomValue, useSetAtom } from 'jotai'
|
||||
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
import { useDebouncedCallback } from 'use-debounce'
|
||||
|
||||
import { copyOverInstructionEnabledAtom } from '@/containers/CopyInstruction'
|
||||
import { fileUploadAtom } from '@/containers/Providers/Jotai'
|
||||
|
||||
import { toaster } from '@/containers/Toast'
|
||||
|
||||
import { generateThreadId } from '@/utils/thread'
|
||||
|
||||
import { useActiveModel } from './useActiveModel'
|
||||
import useRecommendedModel from './useRecommendedModel'
|
||||
|
||||
@ -27,6 +27,7 @@ import useSetActiveThread from './useSetActiveThread'
|
||||
import { extensionManager } from '@/extension'
|
||||
|
||||
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||
import {
|
||||
threadsAtom,
|
||||
@ -34,7 +35,6 @@ import {
|
||||
updateThreadAtom,
|
||||
setThreadModelParamsAtom,
|
||||
isGeneratingResponseAtom,
|
||||
activeThreadAtom,
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
|
||||
@ -64,7 +64,7 @@ export const useCreateNewThread = () => {
|
||||
const copyOverInstructionEnabled = useAtomValue(
|
||||
copyOverInstructionEnabledAtom
|
||||
)
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
const [activeAssistant, setActiveAssistant] = useAtom(activeAssistantAtom)
|
||||
|
||||
const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom)
|
||||
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
|
||||
@ -75,7 +75,7 @@ export const useCreateNewThread = () => {
|
||||
const { stopInference } = useActiveModel()
|
||||
|
||||
const requestCreateNewThread = async (
|
||||
assistant: Assistant,
|
||||
assistant: (ThreadAssistantInfo & { id: string; name: string }) | Assistant,
|
||||
model?: Model | undefined
|
||||
) => {
|
||||
// Stop generating if any
|
||||
@ -124,7 +124,7 @@ export const useCreateNewThread = () => {
|
||||
const createdAt = Date.now()
|
||||
let instructions: string | undefined = assistant.instructions
|
||||
if (copyOverInstructionEnabled) {
|
||||
instructions = activeThread?.assistants[0]?.instructions ?? undefined
|
||||
instructions = activeAssistant?.instructions ?? undefined
|
||||
}
|
||||
const assistantInfo: ThreadAssistantInfo = {
|
||||
assistant_id: assistant.id,
|
||||
@ -139,46 +139,95 @@ export const useCreateNewThread = () => {
|
||||
instructions,
|
||||
}
|
||||
|
||||
const threadId = generateThreadId(assistant.id)
|
||||
const thread: Thread = {
|
||||
id: threadId,
|
||||
const thread: Partial<Thread> = {
|
||||
object: 'thread',
|
||||
title: 'New Thread',
|
||||
assistants: [assistantInfo],
|
||||
created: createdAt,
|
||||
updated: createdAt,
|
||||
metadata: {
|
||||
title: 'New Thread',
|
||||
},
|
||||
}
|
||||
|
||||
// add the new thread on top of the thread list to the state
|
||||
//TODO: Why do we have thread list then thread states? Should combine them
|
||||
createNewThread(thread)
|
||||
try {
|
||||
const createdThread = await persistNewThread(thread, assistantInfo)
|
||||
if (!createdThread) throw 'Thread creation failed'
|
||||
createNewThread(createdThread)
|
||||
|
||||
setSelectedModel(defaultModel)
|
||||
setThreadModelParams(thread.id, {
|
||||
...defaultModel?.settings,
|
||||
...defaultModel?.parameters,
|
||||
...overriddenSettings,
|
||||
})
|
||||
setSelectedModel(defaultModel)
|
||||
setThreadModelParams(createdThread.id, {
|
||||
...defaultModel?.settings,
|
||||
...defaultModel?.parameters,
|
||||
...overriddenSettings,
|
||||
})
|
||||
|
||||
// Delete the file upload state
|
||||
setFileUpload([])
|
||||
// Update thread metadata
|
||||
await updateThreadMetadata(thread)
|
||||
|
||||
setActiveThread(thread)
|
||||
// Delete the file upload state
|
||||
setFileUpload([])
|
||||
setActiveThread(createdThread)
|
||||
} catch (ex) {
|
||||
return toaster({
|
||||
title: 'Thread created failed.',
|
||||
description: `To avoid piling up empty threads, please reuse previous one before creating new.`,
|
||||
type: 'error',
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const updateThreadExtension = (thread: Thread) => {
|
||||
return extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.modifyThread(thread)
|
||||
}
|
||||
|
||||
const updateAssistantExtension = (
|
||||
threadId: string,
|
||||
assistant: ThreadAssistantInfo
|
||||
) => {
|
||||
return extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.modifyThreadAssistant(threadId, assistant)
|
||||
}
|
||||
|
||||
const updateThreadCallback = useDebouncedCallback(updateThreadExtension, 300)
|
||||
const updateAssistantCallback = useDebouncedCallback(
|
||||
updateAssistantExtension,
|
||||
300
|
||||
)
|
||||
|
||||
const updateThreadMetadata = useCallback(
|
||||
async (thread: Thread) => {
|
||||
updateThread(thread)
|
||||
|
||||
await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.saveThread(thread)
|
||||
setActiveAssistant(thread.assistants[0])
|
||||
updateThreadCallback(thread)
|
||||
updateAssistantCallback(thread.id, thread.assistants[0])
|
||||
},
|
||||
[updateThread]
|
||||
[
|
||||
updateThread,
|
||||
setActiveAssistant,
|
||||
updateThreadCallback,
|
||||
updateAssistantCallback,
|
||||
]
|
||||
)
|
||||
|
||||
const persistNewThread = async (
|
||||
thread: Partial<Thread>,
|
||||
assistantInfo: ThreadAssistantInfo
|
||||
): Promise<Thread | undefined> => {
|
||||
return await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.createThread(thread)
|
||||
.then(async (thread) => {
|
||||
await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.createThreadAssistant(thread.id, assistantInfo)
|
||||
return thread
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
requestCreateNewThread,
|
||||
updateThreadMetadata,
|
||||
|
||||
@ -1,13 +1,6 @@
|
||||
import { useCallback } from 'react'
|
||||
|
||||
import {
|
||||
ChatCompletionRole,
|
||||
ExtensionTypeEnum,
|
||||
ConversationalExtension,
|
||||
fs,
|
||||
joinPath,
|
||||
Thread,
|
||||
} from '@janhq/core'
|
||||
import { ExtensionTypeEnum, ConversationalExtension } from '@janhq/core'
|
||||
|
||||
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
@ -15,89 +8,63 @@ import { currentPromptAtom } from '@/containers/Providers/Jotai'
|
||||
|
||||
import { toaster } from '@/containers/Toast'
|
||||
|
||||
import { useCreateNewThread } from './useCreateNewThread'
|
||||
|
||||
import { extensionManager } from '@/extension/ExtensionManager'
|
||||
|
||||
import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||
import {
|
||||
chatMessages,
|
||||
cleanChatMessageAtom as cleanChatMessagesAtom,
|
||||
deleteChatMessageAtom as deleteChatMessagesAtom,
|
||||
} from '@/helpers/atoms/ChatMessage.atom'
|
||||
import { assistantsAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
import { deleteChatMessageAtom as deleteChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
|
||||
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||
import {
|
||||
threadsAtom,
|
||||
setActiveThreadIdAtom,
|
||||
deleteThreadStateAtom,
|
||||
updateThreadStateLastMessageAtom,
|
||||
updateThreadAtom,
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
export default function useDeleteThread() {
|
||||
const [threads, setThreads] = useAtom(threadsAtom)
|
||||
const messages = useAtomValue(chatMessages)
|
||||
const janDataFolderPath = useAtomValue(janDataFolderPathAtom)
|
||||
const { requestCreateNewThread } = useCreateNewThread()
|
||||
const assistants = useAtomValue(assistantsAtom)
|
||||
const models = useAtomValue(downloadedModelsAtom)
|
||||
|
||||
const setCurrentPrompt = useSetAtom(currentPromptAtom)
|
||||
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
|
||||
const deleteMessages = useSetAtom(deleteChatMessagesAtom)
|
||||
const cleanMessages = useSetAtom(cleanChatMessagesAtom)
|
||||
|
||||
const deleteThreadState = useSetAtom(deleteThreadStateAtom)
|
||||
const updateThreadLastMessage = useSetAtom(updateThreadStateLastMessageAtom)
|
||||
const updateThread = useSetAtom(updateThreadAtom)
|
||||
|
||||
const cleanThread = useCallback(
|
||||
async (threadId: string) => {
|
||||
cleanMessages(threadId)
|
||||
const thread = threads.find((c) => c.id === threadId)
|
||||
if (!thread) return
|
||||
const assistantInfo = await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.getThreadAssistant(thread.id)
|
||||
|
||||
const updatedMessages = (messages[threadId] ?? []).filter(
|
||||
(msg) => msg.role === ChatCompletionRole.System
|
||||
if (!assistantInfo) return
|
||||
const model = models.find((c) => c.id === assistantInfo?.model?.id)
|
||||
|
||||
requestCreateNewThread(
|
||||
{
|
||||
...assistantInfo,
|
||||
id: assistants[0].id,
|
||||
name: assistants[0].name,
|
||||
},
|
||||
model
|
||||
? {
|
||||
...model,
|
||||
parameters: assistantInfo?.model?.parameters ?? {},
|
||||
settings: assistantInfo?.model?.settings ?? {},
|
||||
}
|
||||
: undefined
|
||||
)
|
||||
|
||||
// remove files
|
||||
try {
|
||||
const threadFolderPath = await joinPath([
|
||||
janDataFolderPath,
|
||||
'threads',
|
||||
threadId,
|
||||
])
|
||||
const threadFilesPath = await joinPath([threadFolderPath, 'files'])
|
||||
const threadMemoryPath = await joinPath([threadFolderPath, 'memory'])
|
||||
await fs.rm(threadFilesPath)
|
||||
await fs.rm(threadMemoryPath)
|
||||
} catch (err) {
|
||||
console.warn('Error deleting thread files', err)
|
||||
}
|
||||
|
||||
// Delete this thread
|
||||
await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.writeMessages(threadId, updatedMessages)
|
||||
|
||||
thread.metadata = {
|
||||
...thread.metadata,
|
||||
}
|
||||
|
||||
const updatedThread: Thread = {
|
||||
...thread,
|
||||
title: 'New Thread',
|
||||
metadata: { ...thread.metadata, lastMessage: undefined },
|
||||
}
|
||||
|
||||
await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.saveThread(updatedThread)
|
||||
updateThreadLastMessage(threadId, undefined)
|
||||
updateThread(updatedThread)
|
||||
?.deleteThread(threadId)
|
||||
.catch(console.error)
|
||||
},
|
||||
[
|
||||
cleanMessages,
|
||||
threads,
|
||||
messages,
|
||||
updateThreadLastMessage,
|
||||
updateThread,
|
||||
janDataFolderPath,
|
||||
]
|
||||
[assistants, models, requestCreateNewThread, threads]
|
||||
)
|
||||
|
||||
const deleteThread = async (threadId: string) => {
|
||||
@ -105,30 +72,27 @@ export default function useDeleteThread() {
|
||||
alert('No active thread')
|
||||
return
|
||||
}
|
||||
try {
|
||||
await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.deleteThread(threadId)
|
||||
const availableThreads = threads.filter((c) => c.id !== threadId)
|
||||
setThreads(availableThreads)
|
||||
await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.deleteThread(threadId)
|
||||
.catch(console.error)
|
||||
const availableThreads = threads.filter((c) => c.id !== threadId)
|
||||
setThreads(availableThreads)
|
||||
|
||||
// delete the thread state
|
||||
deleteThreadState(threadId)
|
||||
// delete the thread state
|
||||
deleteThreadState(threadId)
|
||||
|
||||
deleteMessages(threadId)
|
||||
setCurrentPrompt('')
|
||||
toaster({
|
||||
title: 'Thread successfully deleted.',
|
||||
description: `Thread ${threadId} has been successfully deleted.`,
|
||||
type: 'success',
|
||||
})
|
||||
if (availableThreads.length > 0) {
|
||||
setActiveThreadId(availableThreads[0].id)
|
||||
} else {
|
||||
setActiveThreadId(undefined)
|
||||
}
|
||||
} catch (err) {
|
||||
console.error(err)
|
||||
deleteMessages(threadId)
|
||||
setCurrentPrompt('')
|
||||
toaster({
|
||||
title: 'Thread successfully deleted.',
|
||||
description: `Thread ${threadId} has been successfully deleted.`,
|
||||
type: 'success',
|
||||
})
|
||||
if (availableThreads.length > 0) {
|
||||
setActiveThreadId(availableThreads[0].id)
|
||||
} else {
|
||||
setActiveThreadId(undefined)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ import { openFileExplorer, joinPath, baseName } from '@janhq/core'
|
||||
import { useAtomValue } from 'jotai'
|
||||
|
||||
import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
@ -9,13 +10,14 @@ export const usePath = () => {
|
||||
const janDataFolderPath = useAtomValue(janDataFolderPathAtom)
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
const selectedModel = useAtomValue(selectedModelAtom)
|
||||
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||
|
||||
const onRevealInFinder = async (type: string) => {
|
||||
// TODO: this logic should be refactored.
|
||||
if (type !== 'Model' && !activeThread) return
|
||||
|
||||
let filePath = undefined
|
||||
const assistantId = activeThread?.assistants[0]?.assistant_id
|
||||
const assistantId = activeAssistant?.assistant_id
|
||||
switch (type) {
|
||||
case 'Engine':
|
||||
case 'Thread':
|
||||
|
||||
@ -6,6 +6,7 @@ import { atom, useAtomValue } from 'jotai'
|
||||
|
||||
import { activeModelAtom } from './useActiveModel'
|
||||
|
||||
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
@ -28,6 +29,7 @@ export default function useRecommendedModel() {
|
||||
const [recommendedModel, setRecommendedModel] = useState<Model | undefined>()
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
const downloadedModels = useAtomValue(downloadedModelsAtom)
|
||||
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||
|
||||
const getAndSortDownloadedModels = useCallback(async (): Promise<Model[]> => {
|
||||
const models = downloadedModels.sort((a, b) =>
|
||||
@ -45,8 +47,8 @@ export default function useRecommendedModel() {
|
||||
> => {
|
||||
const models = await getAndSortDownloadedModels()
|
||||
|
||||
if (!activeThread) return
|
||||
const modelId = activeThread.assistants[0]?.model.id
|
||||
if (!activeThread || !activeAssistant) return
|
||||
const modelId = activeAssistant.model.id
|
||||
const model = models.find((model) => model.id === modelId)
|
||||
|
||||
if (model) {
|
||||
|
||||
@ -10,6 +10,7 @@ import {
|
||||
ConversationalExtension,
|
||||
EngineManager,
|
||||
ToolManager,
|
||||
ThreadAssistantInfo,
|
||||
} from '@janhq/core'
|
||||
import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
|
||||
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||
@ -28,6 +29,7 @@ import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
|
||||
import { useActiveModel } from './useActiveModel'
|
||||
|
||||
import { extensionManager } from '@/extension/ExtensionManager'
|
||||
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
import {
|
||||
addNewMessageAtom,
|
||||
deleteMessageAtom,
|
||||
@ -48,6 +50,7 @@ export const reloadModelAtom = atom(false)
|
||||
|
||||
export default function useSendChatMessage() {
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||
const addNewMessage = useSetAtom(addNewMessageAtom)
|
||||
const updateThread = useSetAtom(updateThreadAtom)
|
||||
const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
|
||||
@ -68,6 +71,7 @@ export default function useSendChatMessage() {
|
||||
const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
|
||||
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
|
||||
const activeThreadRef = useRef<Thread | undefined>()
|
||||
const activeAssistantRef = useRef<ThreadAssistantInfo | undefined>()
|
||||
const setTokenSpeed = useSetAtom(tokenSpeedAtom)
|
||||
|
||||
const selectedModelRef = useRef<Model | undefined>()
|
||||
@ -84,36 +88,37 @@ export default function useSendChatMessage() {
|
||||
selectedModelRef.current = selectedModel
|
||||
}, [selectedModel])
|
||||
|
||||
const resendChatMessage = async (currentMessage: ThreadMessage) => {
|
||||
useEffect(() => {
|
||||
activeAssistantRef.current = activeAssistant
|
||||
}, [activeAssistant])
|
||||
|
||||
const resendChatMessage = async () => {
|
||||
// Delete last response before regenerating
|
||||
const newConvoData = currentMessages
|
||||
let toSendMessage = currentMessage
|
||||
const newConvoData = Array.from(currentMessages)
|
||||
let toSendMessage = newConvoData.pop()
|
||||
|
||||
do {
|
||||
deleteMessage(currentMessage.id)
|
||||
const msg = newConvoData.pop()
|
||||
if (!msg) break
|
||||
toSendMessage = msg
|
||||
deleteMessage(toSendMessage.id ?? '')
|
||||
} while (toSendMessage.role !== ChatCompletionRole.User)
|
||||
|
||||
if (activeThreadRef.current) {
|
||||
while (toSendMessage && toSendMessage?.role !== ChatCompletionRole.User) {
|
||||
await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.writeMessages(activeThreadRef.current.id, newConvoData)
|
||||
?.deleteMessage(toSendMessage.thread_id, toSendMessage.id)
|
||||
.catch(console.error)
|
||||
deleteMessage(toSendMessage.id ?? '')
|
||||
toSendMessage = newConvoData.pop()
|
||||
}
|
||||
|
||||
sendChatMessage(toSendMessage.content[0]?.text.value)
|
||||
if (toSendMessage?.content[0]?.text?.value)
|
||||
sendChatMessage(toSendMessage.content[0].text.value, true)
|
||||
}
|
||||
|
||||
const sendChatMessage = async (
|
||||
message: string,
|
||||
isResend: boolean = false,
|
||||
messages?: ThreadMessage[]
|
||||
) => {
|
||||
if (!message || message.trim().length === 0) return
|
||||
|
||||
if (!activeThreadRef.current) {
|
||||
console.error('No active thread')
|
||||
if (!activeThreadRef.current || !activeAssistantRef.current) {
|
||||
console.error('No active thread or assistant')
|
||||
return
|
||||
}
|
||||
|
||||
@ -139,11 +144,11 @@ export default function useSendChatMessage() {
|
||||
}
|
||||
|
||||
const modelRequest =
|
||||
selectedModelRef?.current ?? activeThreadRef.current.assistants[0].model
|
||||
selectedModelRef?.current ?? activeAssistantRef.current?.model
|
||||
|
||||
// Fallback support for previous broken threads
|
||||
if (activeThreadRef.current?.assistants[0]?.model?.id === '*') {
|
||||
activeThreadRef.current.assistants[0].model = {
|
||||
if (activeAssistantRef.current?.model?.id === '*') {
|
||||
activeAssistantRef.current.model = {
|
||||
id: modelRequest.id,
|
||||
settings: modelRequest.settings,
|
||||
parameters: modelRequest.parameters,
|
||||
@ -163,46 +168,49 @@ export default function useSendChatMessage() {
|
||||
},
|
||||
activeThreadRef.current,
|
||||
messages ?? currentMessages
|
||||
).addSystemMessage(activeThreadRef.current.assistants[0].instructions)
|
||||
).addSystemMessage(activeAssistantRef.current?.instructions)
|
||||
|
||||
requestBuilder.pushMessage(prompt, base64Blob, fileUpload[0]?.type)
|
||||
if (!isResend) {
|
||||
requestBuilder.pushMessage(prompt, base64Blob, fileUpload[0]?.type)
|
||||
|
||||
// Build Thread Message to persist
|
||||
const threadMessageBuilder = new ThreadMessageBuilder(
|
||||
requestBuilder
|
||||
).pushMessage(prompt, base64Blob, fileUpload)
|
||||
// Build Thread Message to persist
|
||||
const threadMessageBuilder = new ThreadMessageBuilder(
|
||||
requestBuilder
|
||||
).pushMessage(prompt, base64Blob, fileUpload)
|
||||
|
||||
const newMessage = threadMessageBuilder.build()
|
||||
const newMessage = threadMessageBuilder.build()
|
||||
|
||||
// Push to states
|
||||
addNewMessage(newMessage)
|
||||
// Update thread state
|
||||
const updatedThread: Thread = {
|
||||
...activeThreadRef.current,
|
||||
updated: newMessage.created,
|
||||
metadata: {
|
||||
...activeThreadRef.current.metadata,
|
||||
lastMessage: prompt,
|
||||
},
|
||||
}
|
||||
updateThread(updatedThread)
|
||||
|
||||
// Update thread state
|
||||
const updatedThread: Thread = {
|
||||
...activeThreadRef.current,
|
||||
updated: newMessage.created,
|
||||
metadata: {
|
||||
...activeThreadRef.current.metadata,
|
||||
lastMessage: prompt,
|
||||
},
|
||||
// Add message
|
||||
const createdMessage = await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.createMessage(newMessage)
|
||||
|
||||
if (!createdMessage) return
|
||||
|
||||
// Push to states
|
||||
addNewMessage(createdMessage)
|
||||
}
|
||||
updateThread(updatedThread)
|
||||
|
||||
// Add message
|
||||
await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.addNewMessage(newMessage)
|
||||
|
||||
// Start Model if not started
|
||||
const modelId =
|
||||
selectedModelRef.current?.id ??
|
||||
activeThreadRef.current.assistants[0].model.id
|
||||
selectedModelRef.current?.id ?? activeAssistantRef.current?.model.id
|
||||
|
||||
if (base64Blob) {
|
||||
setFileUpload([])
|
||||
}
|
||||
|
||||
if (modelRef.current?.id !== modelId) {
|
||||
if (modelRef.current?.id !== modelId && modelId) {
|
||||
const error = await startModel(modelId).catch((error: Error) => error)
|
||||
if (error) {
|
||||
updateThreadWaiting(activeThreadRef.current.id, false)
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
import { ExtensionTypeEnum, Thread, ConversationalExtension } from '@janhq/core'
|
||||
|
||||
import { useAtomValue, useSetAtom } from 'jotai'
|
||||
import { useSetAtom } from 'jotai'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
import {
|
||||
readyThreadsMessagesAtom,
|
||||
setConvoMessagesAtom,
|
||||
} from '@/helpers/atoms/ChatMessage.atom'
|
||||
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
import { setConvoMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
|
||||
import {
|
||||
setActiveThreadIdAtom,
|
||||
setThreadModelParamsAtom,
|
||||
@ -17,21 +15,27 @@ export default function useSetActiveThread() {
|
||||
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
|
||||
const setThreadMessage = useSetAtom(setConvoMessagesAtom)
|
||||
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
|
||||
const readyMessageThreads = useAtomValue(readyThreadsMessagesAtom)
|
||||
const setActiveAssistant = useSetAtom(activeAssistantAtom)
|
||||
|
||||
const setActiveThread = async (thread: Thread) => {
|
||||
// Load local messages only if there are no messages in the state
|
||||
if (!readyMessageThreads[thread?.id]) {
|
||||
const messages = await getLocalThreadMessage(thread?.id)
|
||||
setThreadMessage(thread?.id, messages)
|
||||
}
|
||||
if (!thread?.id) return
|
||||
|
||||
setActiveThreadId(thread?.id)
|
||||
const modelParams: ModelParams = {
|
||||
...thread?.assistants[0]?.model?.parameters,
|
||||
...thread?.assistants[0]?.model?.settings,
|
||||
|
||||
try {
|
||||
const assistantInfo = await getThreadAssistant(thread.id)
|
||||
setActiveAssistant(assistantInfo)
|
||||
// Load local messages only if there are no messages in the state
|
||||
const messages = await getLocalThreadMessage(thread.id).catch(() => [])
|
||||
const modelParams: ModelParams = {
|
||||
...assistantInfo?.model?.parameters,
|
||||
...assistantInfo?.model?.settings,
|
||||
}
|
||||
setThreadModelParams(thread?.id, modelParams)
|
||||
setThreadMessage(thread.id, messages)
|
||||
} catch (e) {
|
||||
console.error(e)
|
||||
}
|
||||
setThreadModelParams(thread?.id, modelParams)
|
||||
}
|
||||
|
||||
return { setActiveThread }
|
||||
@ -40,4 +44,9 @@ export default function useSetActiveThread() {
|
||||
const getLocalThreadMessage = async (threadId: string) =>
|
||||
extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.getAllMessages(threadId) ?? []
|
||||
?.listMessages(threadId) ?? []
|
||||
|
||||
const getThreadAssistant = async (threadId: string) =>
|
||||
extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.getThreadAssistant(threadId)
|
||||
|
||||
@ -68,6 +68,6 @@ const useThreads = () => {
|
||||
const getLocalThreads = async (): Promise<Thread[]> =>
|
||||
(await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.getThreads()) ?? []
|
||||
?.listThreads()) ?? []
|
||||
|
||||
export default useThreads
|
||||
|
||||
@ -12,7 +12,10 @@ import {
|
||||
|
||||
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
import { useDebouncedCallback } from 'use-debounce'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||
import {
|
||||
getActiveThreadModelParamsAtom,
|
||||
@ -29,11 +32,28 @@ export type UpdateModelParameter = {
|
||||
|
||||
export default function useUpdateModelParameters() {
|
||||
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
|
||||
const [activeAssistant, setActiveAssistant] = useAtom(activeAssistantAtom)
|
||||
const [selectedModel] = useAtom(selectedModelAtom)
|
||||
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
|
||||
|
||||
const updateAssistantExtension = (
|
||||
threadId: string,
|
||||
assistant: ThreadAssistantInfo
|
||||
) => {
|
||||
return extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.modifyThreadAssistant(threadId, assistant)
|
||||
}
|
||||
|
||||
const updateAssistantCallback = useDebouncedCallback(
|
||||
updateAssistantExtension,
|
||||
300
|
||||
)
|
||||
|
||||
const updateModelParameter = useCallback(
|
||||
async (thread: Thread, settings: UpdateModelParameter) => {
|
||||
if (!activeAssistant) return
|
||||
|
||||
const toUpdateSettings = processStopWords(settings.params ?? {})
|
||||
const updatedModelParams = settings.modelId
|
||||
? toUpdateSettings
|
||||
@ -48,30 +68,33 @@ export default function useUpdateModelParameters() {
|
||||
setThreadModelParams(thread.id, updatedModelParams)
|
||||
const runtimeParams = extractInferenceParams(updatedModelParams)
|
||||
const settingParams = extractModelLoadParams(updatedModelParams)
|
||||
|
||||
const assistants = thread.assistants.map(
|
||||
(assistant: ThreadAssistantInfo) => {
|
||||
assistant.model.parameters = runtimeParams
|
||||
assistant.model.settings = settingParams
|
||||
if (selectedModel) {
|
||||
assistant.model.id = settings.modelId ?? selectedModel?.id
|
||||
assistant.model.engine = settings.engine ?? selectedModel?.engine
|
||||
}
|
||||
return assistant
|
||||
}
|
||||
)
|
||||
|
||||
// update thread
|
||||
const updatedThread: Thread = {
|
||||
...thread,
|
||||
assistants,
|
||||
const assistantInfo = {
|
||||
...activeAssistant,
|
||||
model: {
|
||||
...activeAssistant?.model,
|
||||
parameters: runtimeParams,
|
||||
settings: settingParams,
|
||||
id: settings.modelId ?? selectedModel?.id ?? activeAssistant.model.id,
|
||||
engine:
|
||||
settings.engine ??
|
||||
selectedModel?.engine ??
|
||||
activeAssistant.model.engine,
|
||||
},
|
||||
}
|
||||
|
||||
await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.saveThread(updatedThread)
|
||||
setActiveAssistant(assistantInfo)
|
||||
updateAssistantCallback(thread.id, assistantInfo)
|
||||
},
|
||||
[activeModelParams, selectedModel, setThreadModelParams]
|
||||
[
|
||||
activeAssistant,
|
||||
selectedModel?.parameters,
|
||||
selectedModel?.settings,
|
||||
selectedModel?.id,
|
||||
selectedModel?.engine,
|
||||
activeModelParams,
|
||||
setThreadModelParams,
|
||||
setActiveAssistant,
|
||||
updateAssistantCallback,
|
||||
]
|
||||
)
|
||||
|
||||
const processStopWords = (params: ModelParams): ModelParams => {
|
||||
|
||||
@ -8,6 +8,7 @@ import { useCreateNewThread } from '@/hooks/useCreateNewThread'
|
||||
|
||||
import SettingComponentBuilder from '../../../../containers/ModelSetting/SettingComponent'
|
||||
|
||||
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
import {
|
||||
activeThreadAtom,
|
||||
engineParamsUpdateAtom,
|
||||
@ -19,13 +20,14 @@ type Props = {
|
||||
|
||||
const AssistantSetting: React.FC<Props> = ({ componentData }) => {
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||
const { updateThreadMetadata } = useCreateNewThread()
|
||||
const { stopModel } = useActiveModel()
|
||||
const setEngineParamsUpdate = useSetAtom(engineParamsUpdateAtom)
|
||||
|
||||
const onValueChanged = useCallback(
|
||||
(key: string, value: string | number | boolean | string[]) => {
|
||||
if (!activeThread) return
|
||||
if (!activeThread || !activeAssistant) return
|
||||
const shouldReloadModel =
|
||||
componentData.find((x) => x.key === key)?.requireModelReload ?? false
|
||||
if (shouldReloadModel) {
|
||||
@ -34,40 +36,40 @@ const AssistantSetting: React.FC<Props> = ({ componentData }) => {
|
||||
}
|
||||
|
||||
if (
|
||||
activeThread.assistants[0].tools &&
|
||||
activeAssistant?.tools &&
|
||||
(key === 'chunk_overlap' || key === 'chunk_size')
|
||||
) {
|
||||
if (
|
||||
activeThread.assistants[0].tools[0]?.settings?.chunk_size <
|
||||
activeThread.assistants[0].tools[0]?.settings?.chunk_overlap
|
||||
activeAssistant.tools[0]?.settings?.chunk_size <
|
||||
activeAssistant.tools[0]?.settings?.chunk_overlap
|
||||
) {
|
||||
activeThread.assistants[0].tools[0].settings.chunk_overlap =
|
||||
activeThread.assistants[0].tools[0].settings.chunk_size
|
||||
activeAssistant.tools[0].settings.chunk_overlap =
|
||||
activeAssistant.tools[0].settings.chunk_size
|
||||
}
|
||||
if (
|
||||
key === 'chunk_size' &&
|
||||
value < activeThread.assistants[0].tools[0].settings?.chunk_overlap
|
||||
value < activeAssistant.tools[0].settings?.chunk_overlap
|
||||
) {
|
||||
activeThread.assistants[0].tools[0].settings.chunk_overlap = value
|
||||
activeAssistant.tools[0].settings.chunk_overlap = value
|
||||
} else if (
|
||||
key === 'chunk_overlap' &&
|
||||
value > activeThread.assistants[0].tools[0].settings?.chunk_size
|
||||
value > activeAssistant.tools[0].settings?.chunk_size
|
||||
) {
|
||||
activeThread.assistants[0].tools[0].settings.chunk_size = value
|
||||
activeAssistant.tools[0].settings.chunk_size = value
|
||||
}
|
||||
}
|
||||
updateThreadMetadata({
|
||||
...activeThread,
|
||||
assistants: [
|
||||
{
|
||||
...activeThread.assistants[0],
|
||||
...activeAssistant,
|
||||
tools: [
|
||||
{
|
||||
type: 'retrieval',
|
||||
enabled: true,
|
||||
settings: {
|
||||
...(activeThread.assistants[0].tools &&
|
||||
activeThread.assistants[0].tools[0]?.settings),
|
||||
...(activeAssistant.tools &&
|
||||
activeAssistant.tools[0]?.settings),
|
||||
[key]: value,
|
||||
},
|
||||
},
|
||||
@ -77,6 +79,7 @@ const AssistantSetting: React.FC<Props> = ({ componentData }) => {
|
||||
})
|
||||
},
|
||||
[
|
||||
activeAssistant,
|
||||
activeThread,
|
||||
componentData,
|
||||
setEngineParamsUpdate,
|
||||
|
||||
@ -33,6 +33,7 @@ import RichTextEditor from './RichTextEditor'
|
||||
|
||||
import { showRightPanelAtom } from '@/helpers/atoms/App.atom'
|
||||
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
|
||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||
import { spellCheckAtom } from '@/helpers/atoms/Setting.atom'
|
||||
@ -67,6 +68,7 @@ const ChatInput = () => {
|
||||
const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom)
|
||||
const isGeneratingResponse = useAtomValue(isGeneratingResponseAtom)
|
||||
const threadStates = useAtomValue(threadStatesAtom)
|
||||
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||
const { stopInference } = useActiveModel()
|
||||
|
||||
const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom(
|
||||
@ -153,9 +155,9 @@ const ChatInput = () => {
|
||||
onClick={(e) => {
|
||||
if (
|
||||
fileUpload.length > 0 ||
|
||||
(activeThread?.assistants[0].tools &&
|
||||
!activeThread?.assistants[0].tools[0]?.enabled &&
|
||||
!activeThread?.assistants[0].model.settings?.vision_model)
|
||||
(activeAssistant?.tools &&
|
||||
!activeAssistant?.tools[0]?.enabled &&
|
||||
!activeAssistant?.model.settings?.vision_model)
|
||||
) {
|
||||
e.stopPropagation()
|
||||
} else {
|
||||
@ -171,16 +173,15 @@ const ChatInput = () => {
|
||||
}
|
||||
disabled={
|
||||
isModelSupportRagAndTools &&
|
||||
activeThread?.assistants[0].tools &&
|
||||
activeThread?.assistants[0].tools[0]?.enabled
|
||||
activeAssistant?.tools &&
|
||||
activeAssistant?.tools[0]?.enabled
|
||||
}
|
||||
content={
|
||||
<>
|
||||
{fileUpload.length > 0 ||
|
||||
(activeThread?.assistants[0].tools &&
|
||||
!activeThread?.assistants[0].tools[0]?.enabled &&
|
||||
!activeThread?.assistants[0].model.settings
|
||||
?.vision_model && (
|
||||
(activeAssistant?.tools &&
|
||||
!activeAssistant?.tools[0]?.enabled &&
|
||||
!activeAssistant?.model.settings?.vision_model && (
|
||||
<>
|
||||
{fileUpload.length !== 0 && (
|
||||
<span>
|
||||
@ -188,9 +189,8 @@ const ChatInput = () => {
|
||||
time.
|
||||
</span>
|
||||
)}
|
||||
{activeThread?.assistants[0].tools &&
|
||||
activeThread?.assistants[0].tools[0]?.enabled ===
|
||||
false &&
|
||||
{activeAssistant?.tools &&
|
||||
activeAssistant?.tools[0]?.enabled === false &&
|
||||
isModelSupportRagAndTools && (
|
||||
<span>
|
||||
Turn on Retrieval in Tools settings to use this
|
||||
@ -221,14 +221,12 @@ const ChatInput = () => {
|
||||
<li
|
||||
className={twMerge(
|
||||
'text-[hsla(var(--text-secondary)] hover:bg-secondary flex w-full items-center space-x-2 px-4 py-2 hover:bg-[hsla(var(--dropdown-menu-hover-bg))]',
|
||||
activeThread?.assistants[0].model.settings?.vision_model
|
||||
activeAssistant?.model.settings?.vision_model
|
||||
? 'cursor-pointer'
|
||||
: 'cursor-not-allowed opacity-50'
|
||||
)}
|
||||
onClick={() => {
|
||||
if (
|
||||
activeThread?.assistants[0].model.settings?.vision_model
|
||||
) {
|
||||
if (activeAssistant?.model.settings?.vision_model) {
|
||||
imageInputRef.current?.click()
|
||||
setShowAttacmentMenus(false)
|
||||
}
|
||||
@ -239,9 +237,7 @@ const ChatInput = () => {
|
||||
</li>
|
||||
}
|
||||
content="This feature only supports multimodal models."
|
||||
disabled={
|
||||
activeThread?.assistants[0].model.settings?.vision_model
|
||||
}
|
||||
disabled={activeAssistant?.model.settings?.vision_model}
|
||||
/>
|
||||
<Tooltip
|
||||
side="bottom"
|
||||
@ -261,8 +257,8 @@ const ChatInput = () => {
|
||||
</li>
|
||||
}
|
||||
content={
|
||||
(!activeThread?.assistants[0].tools ||
|
||||
!activeThread?.assistants[0].tools[0]?.enabled) && (
|
||||
(!activeAssistant?.tools ||
|
||||
!activeAssistant?.tools[0]?.enabled) && (
|
||||
<span>
|
||||
Turn on Retrieval in Assistant Settings to use this
|
||||
feature.
|
||||
|
||||
@ -80,19 +80,17 @@ const EditChatInput: React.FC<Props> = ({ message }) => {
|
||||
setEditMessage('')
|
||||
const messageIdx = messages.findIndex((msg) => msg.id === message.id)
|
||||
const newMessages = messages.slice(0, messageIdx)
|
||||
if (activeThread) {
|
||||
setMessages(activeThread.id, newMessages)
|
||||
await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.writeMessages(
|
||||
activeThread.id,
|
||||
// Remove all of the messages below this
|
||||
newMessages
|
||||
)
|
||||
.then(() => {
|
||||
sendChatMessage(editPrompt, newMessages)
|
||||
})
|
||||
}
|
||||
const toDeleteMessages = messages.slice(messageIdx)
|
||||
const threadId = messages[0].thread_id
|
||||
await Promise.all(
|
||||
toDeleteMessages.map(async (message) =>
|
||||
extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.deleteMessage(message.thread_id, message.id)
|
||||
)
|
||||
)
|
||||
setMessages(threadId, newMessages)
|
||||
sendChatMessage(editPrompt, false, newMessages)
|
||||
}
|
||||
|
||||
const onKeyDown = async (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
|
||||
@ -10,15 +10,15 @@ import { MainViewState } from '@/constants/screens'
|
||||
import { loadModelErrorAtom } from '@/hooks/useActiveModel'
|
||||
|
||||
import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
|
||||
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom'
|
||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
const LoadModelError = () => {
|
||||
const setModalTroubleShooting = useSetAtom(modalTroubleShootingAtom)
|
||||
const loadModelError = useAtomValue(loadModelErrorAtom)
|
||||
const setMainState = useSetAtom(mainViewStateAtom)
|
||||
const setSelectedSettingScreen = useSetAtom(selectedSettingAtom)
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||
|
||||
const ErrorMessage = () => {
|
||||
if (
|
||||
@ -33,9 +33,9 @@ const LoadModelError = () => {
|
||||
className="cursor-pointer font-medium text-[hsla(var(--app-link))]"
|
||||
onClick={() => {
|
||||
setMainState(MainViewState.Settings)
|
||||
if (activeThread?.assistants[0]?.model.engine) {
|
||||
if (activeAssistant?.model.engine) {
|
||||
const engine = EngineManager.instance().get(
|
||||
activeThread.assistants[0].model.engine
|
||||
activeAssistant.model.engine
|
||||
)
|
||||
engine?.name && setSelectedSettingScreen(engine.name)
|
||||
}
|
||||
|
||||
@ -58,12 +58,8 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
|
||||
// Should also delete error messages to clear out the error state
|
||||
await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.writeMessages(
|
||||
thread.id,
|
||||
messages.filter(
|
||||
(msg) => msg.id !== message.id && msg.status !== MessageStatus.Error
|
||||
)
|
||||
)
|
||||
?.deleteMessage(thread.id, message.id)
|
||||
.catch(console.error)
|
||||
|
||||
const updatedThread: Thread = {
|
||||
...thread,
|
||||
@ -89,10 +85,6 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
|
||||
setEditMessage(message.id ?? '')
|
||||
}
|
||||
|
||||
const onRegenerateClick = async () => {
|
||||
resendChatMessage(message)
|
||||
}
|
||||
|
||||
if (message.status === MessageStatus.Pending) return null
|
||||
|
||||
return (
|
||||
@ -122,7 +114,7 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
|
||||
ContentType.Pdf && (
|
||||
<div
|
||||
className="cursor-pointer rounded-lg border border-[hsla(var(--app-border))] p-2"
|
||||
onClick={onRegenerateClick}
|
||||
onClick={resendChatMessage}
|
||||
>
|
||||
<Tooltip
|
||||
trigger={
|
||||
|
||||
@ -17,11 +17,11 @@ import DocMessage from './DocMessage'
|
||||
import ImageMessage from './ImageMessage'
|
||||
import { MarkdownTextMessage } from './MarkdownTextMessage'
|
||||
|
||||
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
import {
|
||||
editMessageAtom,
|
||||
tokenSpeedAtom,
|
||||
} from '@/helpers/atoms/ChatMessage.atom'
|
||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
const MessageContainer: React.FC<
|
||||
ThreadMessage & { isCurrentMessage: boolean }
|
||||
@ -29,7 +29,7 @@ const MessageContainer: React.FC<
|
||||
const isUser = props.role === ChatCompletionRole.User
|
||||
const isSystem = props.role === ChatCompletionRole.System
|
||||
const editMessage = useAtomValue(editMessageAtom)
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||
const tokenSpeed = useAtomValue(tokenSpeedAtom)
|
||||
|
||||
const text = useMemo(
|
||||
@ -75,10 +75,10 @@ const MessageContainer: React.FC<
|
||||
>
|
||||
{isUser
|
||||
? props.role
|
||||
: (activeThread?.assistants[0].assistant_name ?? props.role)}
|
||||
: (activeAssistant?.assistant_name ?? props.role)}
|
||||
</div>
|
||||
<p className="text-xs font-medium text-gray-400">
|
||||
{displayDate(props.created)}
|
||||
{props.created && displayDate(props.created ?? new Date())}
|
||||
</p>
|
||||
{tokenSpeed &&
|
||||
tokenSpeed.message === props.id &&
|
||||
|
||||
@ -27,6 +27,7 @@ import RequestDownloadModel from './RequestDownloadModel'
|
||||
|
||||
import { showSystemMonitorPanelAtom } from '@/helpers/atoms/App.atom'
|
||||
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
import {
|
||||
@ -55,9 +56,9 @@ const ThreadCenterPanel = () => {
|
||||
const setFileUpload = useSetAtom(fileUploadAtom)
|
||||
const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom)
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||
|
||||
const acceptedFormat: Accept = activeThread?.assistants[0].model.settings
|
||||
?.vision_model
|
||||
const acceptedFormat: Accept = activeAssistant?.model.settings?.vision_model
|
||||
? {
|
||||
'application/pdf': ['.pdf'],
|
||||
'image/jpeg': ['.jpeg'],
|
||||
@ -78,14 +79,13 @@ const ThreadCenterPanel = () => {
|
||||
if (!experimentalFeature) return
|
||||
if (
|
||||
e.dataTransfer.items.length === 1 &&
|
||||
((activeThread?.assistants[0].tools &&
|
||||
activeThread?.assistants[0].tools[0]?.enabled) ||
|
||||
activeThread?.assistants[0].model.settings?.vision_model)
|
||||
((activeAssistant?.tools && activeAssistant?.tools[0]?.enabled) ||
|
||||
activeAssistant?.model.settings?.vision_model)
|
||||
) {
|
||||
setDragOver(true)
|
||||
} else if (
|
||||
activeThread?.assistants[0].tools &&
|
||||
!activeThread?.assistants[0].tools[0]?.enabled
|
||||
activeAssistant?.tools &&
|
||||
!activeAssistant?.tools[0]?.enabled
|
||||
) {
|
||||
setDragRejected({ code: 'retrieval-off' })
|
||||
} else {
|
||||
@ -100,9 +100,9 @@ const ThreadCenterPanel = () => {
|
||||
!files ||
|
||||
files.length !== 1 ||
|
||||
rejectFiles.length !== 0 ||
|
||||
(activeThread?.assistants[0].tools &&
|
||||
!activeThread?.assistants[0].tools[0]?.enabled &&
|
||||
!activeThread?.assistants[0].model.settings?.vision_model)
|
||||
(activeAssistant?.tools &&
|
||||
!activeAssistant?.tools[0]?.enabled &&
|
||||
!activeAssistant?.model.settings?.vision_model)
|
||||
)
|
||||
return
|
||||
const imageType = files[0]?.type.includes('image')
|
||||
@ -110,10 +110,7 @@ const ThreadCenterPanel = () => {
|
||||
setDragOver(false)
|
||||
},
|
||||
onDropRejected: (e) => {
|
||||
if (
|
||||
activeThread?.assistants[0].tools &&
|
||||
!activeThread?.assistants[0].tools[0]?.enabled
|
||||
) {
|
||||
if (activeAssistant?.tools && !activeAssistant?.tools[0]?.enabled) {
|
||||
setDragRejected({ code: 'retrieval-off' })
|
||||
} else {
|
||||
setDragRejected({ code: e[0].errors[0].code })
|
||||
@ -186,8 +183,7 @@ const ThreadCenterPanel = () => {
|
||||
<h6 className="font-bold">
|
||||
{isDragReject
|
||||
? `Currently, we only support 1 attachment at the same time with ${
|
||||
activeThread?.assistants[0].model.settings
|
||||
?.vision_model
|
||||
activeAssistant?.model.settings?.vision_model
|
||||
? 'PDF, JPEG, JPG, PNG'
|
||||
: 'PDF'
|
||||
} format`
|
||||
@ -195,7 +191,7 @@ const ThreadCenterPanel = () => {
|
||||
</h6>
|
||||
{!isDragReject && (
|
||||
<p className="mt-2">
|
||||
{activeThread?.assistants[0].model.settings?.vision_model
|
||||
{activeAssistant?.model.settings?.vision_model
|
||||
? 'PDF, JPEG, JPG, PNG'
|
||||
: 'PDF'}
|
||||
</p>
|
||||
|
||||
@ -15,13 +15,15 @@ const ModalEditTitleThread = () => {
|
||||
const [modalActionThread, setModalActionThread] = useAtom(
|
||||
modalActionThreadAtom
|
||||
)
|
||||
const [title, setTitle] = useState(modalActionThread.thread?.title as string)
|
||||
const [title, setTitle] = useState(
|
||||
modalActionThread.thread?.metadata?.title as string
|
||||
)
|
||||
|
||||
useLayoutEffect(() => {
|
||||
if (modalActionThread.thread?.title) {
|
||||
setTitle(modalActionThread.thread?.title)
|
||||
if (modalActionThread.thread?.metadata?.title) {
|
||||
setTitle(modalActionThread.thread?.metadata?.title as string)
|
||||
}
|
||||
}, [modalActionThread.thread?.title])
|
||||
}, [modalActionThread.thread?.metadata])
|
||||
|
||||
const onUpdateTitle = useCallback(
|
||||
(e: React.MouseEvent<HTMLButtonElement, MouseEvent>) => {
|
||||
@ -30,6 +32,10 @@ const ModalEditTitleThread = () => {
|
||||
updateThreadMetadata({
|
||||
...modalActionThread?.thread,
|
||||
title: title || 'New Thread',
|
||||
metadata: {
|
||||
...modalActionThread?.thread.metadata,
|
||||
title: title || 'New Thread',
|
||||
},
|
||||
})
|
||||
},
|
||||
[modalActionThread?.thread, title, updateThreadMetadata]
|
||||
|
||||
@ -20,7 +20,10 @@ import { useCreateNewThread } from '@/hooks/useCreateNewThread'
|
||||
import useRecommendedModel from '@/hooks/useRecommendedModel'
|
||||
import useSetActiveThread from '@/hooks/useSetActiveThread'
|
||||
|
||||
import { assistantsAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
import {
|
||||
activeAssistantAtom,
|
||||
assistantsAtom,
|
||||
} from '@/helpers/atoms/Assistant.atom'
|
||||
import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom'
|
||||
|
||||
import {
|
||||
@ -34,6 +37,7 @@ import {
|
||||
const ThreadLeftPanel = () => {
|
||||
const threads = useAtomValue(threadsAtom)
|
||||
const activeThreadId = useAtomValue(getActiveThreadIdAtom)
|
||||
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||
const { setActiveThread } = useSetActiveThread()
|
||||
const assistants = useAtomValue(assistantsAtom)
|
||||
const threadDataReady = useAtomValue(threadDataReadyAtom)
|
||||
@ -67,6 +71,7 @@ const ThreadLeftPanel = () => {
|
||||
useEffect(() => {
|
||||
if (
|
||||
threadDataReady &&
|
||||
activeAssistant &&
|
||||
assistants.length > 0 &&
|
||||
threads.length === 0 &&
|
||||
downloadedModels.length > 0
|
||||
@ -75,7 +80,13 @@ const ThreadLeftPanel = () => {
|
||||
(model) => model.engine === InferenceEngine.cortex_llamacpp
|
||||
)
|
||||
const selectedModel = model[0] || recommendedModel
|
||||
requestCreateNewThread(assistants[0], selectedModel)
|
||||
requestCreateNewThread(
|
||||
{
|
||||
...assistants[0],
|
||||
...activeAssistant,
|
||||
},
|
||||
selectedModel
|
||||
)
|
||||
} else if (threadDataReady && !activeThreadId) {
|
||||
setActiveThread(threads[0])
|
||||
}
|
||||
@ -88,6 +99,7 @@ const ThreadLeftPanel = () => {
|
||||
setActiveThread,
|
||||
recommendedModel,
|
||||
downloadedModels,
|
||||
activeAssistant,
|
||||
])
|
||||
|
||||
const onContextMenu = (event: React.MouseEvent, thread: Thread) => {
|
||||
@ -138,7 +150,7 @@ const ThreadLeftPanel = () => {
|
||||
activeThreadId && 'font-medium'
|
||||
)}
|
||||
>
|
||||
{thread.title}
|
||||
{thread.title ?? thread.metadata?.title}
|
||||
</h1>
|
||||
</div>
|
||||
<div
|
||||
|
||||
@ -14,48 +14,54 @@ import AssistantSetting from '@/screens/Thread/ThreadCenterPanel/AssistantSettin
|
||||
import { getConfigurationsData } from '@/utils/componentSettings'
|
||||
|
||||
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
const Tools = () => {
|
||||
const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom)
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||
const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
|
||||
const { updateThreadMetadata } = useCreateNewThread()
|
||||
const { recommendedModel, downloadedModels } = useRecommendedModel()
|
||||
|
||||
const componentDataAssistantSetting = getConfigurationsData(
|
||||
(activeThread?.assistants[0]?.tools &&
|
||||
activeThread?.assistants[0]?.tools[0]?.settings) ??
|
||||
{}
|
||||
(activeAssistant?.tools && activeAssistant?.tools[0]?.settings) ?? {}
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
if (!activeThread) return
|
||||
let model = downloadedModels.find(
|
||||
(model) => model.id === activeThread.assistants[0].model.id
|
||||
(model) => model.id === activeAssistant?.model.id
|
||||
)
|
||||
if (!model) {
|
||||
model = recommendedModel
|
||||
}
|
||||
setSelectedModel(model)
|
||||
}, [recommendedModel, activeThread, downloadedModels, setSelectedModel])
|
||||
}, [
|
||||
recommendedModel,
|
||||
activeThread,
|
||||
downloadedModels,
|
||||
setSelectedModel,
|
||||
activeAssistant?.model.id,
|
||||
])
|
||||
|
||||
const onRetrievalSwitchUpdate = useCallback(
|
||||
(enabled: boolean) => {
|
||||
if (!activeThread) return
|
||||
if (!activeThread || !activeAssistant) return
|
||||
updateThreadMetadata({
|
||||
...activeThread,
|
||||
assistants: [
|
||||
{
|
||||
...activeThread.assistants[0],
|
||||
...activeAssistant,
|
||||
tools: [
|
||||
{
|
||||
type: 'retrieval',
|
||||
enabled: enabled,
|
||||
settings:
|
||||
(activeThread.assistants[0].tools &&
|
||||
activeThread.assistants[0].tools[0]?.settings) ??
|
||||
(activeAssistant.tools &&
|
||||
activeAssistant.tools[0]?.settings) ??
|
||||
{},
|
||||
},
|
||||
],
|
||||
@ -63,25 +69,25 @@ const Tools = () => {
|
||||
],
|
||||
})
|
||||
},
|
||||
[activeThread, updateThreadMetadata]
|
||||
[activeAssistant, activeThread, updateThreadMetadata]
|
||||
)
|
||||
|
||||
const onTimeWeightedRetrieverSwitchUpdate = useCallback(
|
||||
(enabled: boolean) => {
|
||||
if (!activeThread) return
|
||||
if (!activeThread || !activeAssistant) return
|
||||
updateThreadMetadata({
|
||||
...activeThread,
|
||||
assistants: [
|
||||
{
|
||||
...activeThread.assistants[0],
|
||||
...activeAssistant,
|
||||
tools: [
|
||||
{
|
||||
type: 'retrieval',
|
||||
enabled: true,
|
||||
useTimeWeightedRetriever: enabled,
|
||||
settings:
|
||||
(activeThread.assistants[0].tools &&
|
||||
activeThread.assistants[0].tools[0]?.settings) ??
|
||||
(activeAssistant.tools &&
|
||||
activeAssistant.tools[0]?.settings) ??
|
||||
{},
|
||||
},
|
||||
],
|
||||
@ -89,23 +95,54 @@ const Tools = () => {
|
||||
],
|
||||
})
|
||||
},
|
||||
[activeThread, updateThreadMetadata]
|
||||
[activeAssistant, activeThread, updateThreadMetadata]
|
||||
)
|
||||
|
||||
if (!experimentalFeature) return null
|
||||
|
||||
return (
|
||||
<Fragment>
|
||||
{activeThread?.assistants[0]?.tools &&
|
||||
componentDataAssistantSetting.length > 0 && (
|
||||
<div className="p-4">
|
||||
<div className="mb-2">
|
||||
{activeAssistant?.tools && componentDataAssistantSetting.length > 0 && (
|
||||
<div className="p-4">
|
||||
<div className="mb-2">
|
||||
<div className="flex items-center justify-between">
|
||||
<label
|
||||
id="retrieval"
|
||||
className="inline-flex items-center font-medium"
|
||||
>
|
||||
Retrieval
|
||||
<Tooltip
|
||||
trigger={
|
||||
<InfoIcon
|
||||
size={16}
|
||||
className="ml-2 flex-shrink-0 text-[hsl(var(--text-secondary))]"
|
||||
/>
|
||||
}
|
||||
content="Retrieval helps the assistant use information from
|
||||
files you send to it. Once you share a file, the
|
||||
assistant automatically fetches the relevant content
|
||||
based on your request."
|
||||
/>
|
||||
</label>
|
||||
<div className="flex items-center justify-between">
|
||||
<label
|
||||
id="retrieval"
|
||||
className="inline-flex items-center font-medium"
|
||||
>
|
||||
Retrieval
|
||||
<Switch
|
||||
name="retrieval"
|
||||
checked={activeAssistant?.tools[0].enabled}
|
||||
onChange={(e) => onRetrievalSwitchUpdate(e.target.checked)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{activeAssistant?.tools[0].enabled && (
|
||||
<div className="pb-4 pt-2">
|
||||
<div className="mb-4">
|
||||
<div className="item-center mb-2 flex">
|
||||
<label
|
||||
id="embedding-model"
|
||||
className="inline-flex font-medium"
|
||||
>
|
||||
Embedding Model
|
||||
</label>
|
||||
<Tooltip
|
||||
trigger={
|
||||
<InfoIcon
|
||||
@ -113,90 +150,26 @@ const Tools = () => {
|
||||
className="ml-2 flex-shrink-0 text-[hsl(var(--text-secondary))]"
|
||||
/>
|
||||
}
|
||||
content="Retrieval helps the assistant use information from
|
||||
files you send to it. Once you share a file, the
|
||||
assistant automatically fetches the relevant content
|
||||
based on your request."
|
||||
/>
|
||||
</label>
|
||||
<div className="flex items-center justify-between">
|
||||
<Switch
|
||||
name="retrieval"
|
||||
checked={activeThread?.assistants[0].tools[0].enabled}
|
||||
onChange={(e) => onRetrievalSwitchUpdate(e.target.checked)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{activeThread?.assistants[0]?.tools[0].enabled && (
|
||||
<div className="pb-4 pt-2">
|
||||
<div className="mb-4">
|
||||
<div className="item-center mb-2 flex">
|
||||
<label
|
||||
id="embedding-model"
|
||||
className="inline-flex font-medium"
|
||||
>
|
||||
Embedding Model
|
||||
</label>
|
||||
<Tooltip
|
||||
trigger={
|
||||
<InfoIcon
|
||||
size={16}
|
||||
className="ml-2 flex-shrink-0 text-[hsl(var(--text-secondary))]"
|
||||
/>
|
||||
}
|
||||
content="Embedding model is crucial for understanding and
|
||||
content="Embedding model is crucial for understanding and
|
||||
processing the input text effectively by
|
||||
converting text to numerical representations.
|
||||
Align the model choice with your task, evaluate
|
||||
its performance, and consider factors like
|
||||
resource availability. Experiment to find the best
|
||||
fit for your specific use case."
|
||||
/>
|
||||
</div>
|
||||
<div className="w-full">
|
||||
<Input
|
||||
value={selectedModel?.name || ''}
|
||||
disabled
|
||||
readOnly
|
||||
/>
|
||||
</div>
|
||||
/>
|
||||
</div>
|
||||
<div className="mb-4">
|
||||
<div className="mb-2 flex items-center">
|
||||
<label
|
||||
id="vector-database"
|
||||
className="inline-flex items-center font-medium"
|
||||
>
|
||||
Vector Database
|
||||
<Tooltip
|
||||
trigger={
|
||||
<InfoIcon
|
||||
size={16}
|
||||
className="ml-2 flex-shrink-0 text-[hsl(var(--text-secondary))]"
|
||||
/>
|
||||
}
|
||||
content="Vector Database is crucial for efficient storage
|
||||
and retrieval of embeddings. Consider your
|
||||
specific task, available resources, and language
|
||||
requirements. Experiment to find the best fit for
|
||||
your specific use case."
|
||||
/>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div className="w-full">
|
||||
<Input value="HNSWLib" disabled readOnly />
|
||||
</div>
|
||||
<div className="w-full">
|
||||
<Input value={selectedModel?.name || ''} disabled readOnly />
|
||||
</div>
|
||||
<div className="mb-4">
|
||||
<div className="mb-2 flex items-center">
|
||||
<label
|
||||
id="use-time-weighted-retriever"
|
||||
className="inline-block font-medium"
|
||||
>
|
||||
Time-Weighted Retrieval?
|
||||
</label>
|
||||
</div>
|
||||
<div className="mb-4">
|
||||
<div className="mb-2 flex items-center">
|
||||
<label
|
||||
id="vector-database"
|
||||
className="inline-flex items-center font-medium"
|
||||
>
|
||||
Vector Database
|
||||
<Tooltip
|
||||
trigger={
|
||||
<InfoIcon
|
||||
@ -204,33 +177,59 @@ const Tools = () => {
|
||||
className="ml-2 flex-shrink-0 text-[hsl(var(--text-secondary))]"
|
||||
/>
|
||||
}
|
||||
content="Time-Weighted Retriever looks at how similar
|
||||
content="Vector Database is crucial for efficient storage
|
||||
and retrieval of embeddings. Consider your
|
||||
specific task, available resources, and language
|
||||
requirements. Experiment to find the best fit for
|
||||
your specific use case."
|
||||
/>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div className="w-full">
|
||||
<Input value="HNSWLib" disabled readOnly />
|
||||
</div>
|
||||
</div>
|
||||
<div className="mb-4">
|
||||
<div className="mb-2 flex items-center">
|
||||
<label
|
||||
id="use-time-weighted-retriever"
|
||||
className="inline-block font-medium"
|
||||
>
|
||||
Time-Weighted Retrieval?
|
||||
</label>
|
||||
<Tooltip
|
||||
trigger={
|
||||
<InfoIcon
|
||||
size={16}
|
||||
className="ml-2 flex-shrink-0 text-[hsl(var(--text-secondary))]"
|
||||
/>
|
||||
}
|
||||
content="Time-Weighted Retriever looks at how similar
|
||||
they are and how new they are. It compares
|
||||
documents based on their meaning like usual, but
|
||||
also considers when they were added to give
|
||||
newer ones more importance."
|
||||
/>
|
||||
<div className="ml-auto flex items-center justify-between">
|
||||
<Switch
|
||||
name="use-time-weighted-retriever"
|
||||
checked={
|
||||
activeAssistant?.tools[0].useTimeWeightedRetriever ||
|
||||
false
|
||||
}
|
||||
onChange={(e) =>
|
||||
onTimeWeightedRetrieverSwitchUpdate(e.target.checked)
|
||||
}
|
||||
/>
|
||||
<div className="ml-auto flex items-center justify-between">
|
||||
<Switch
|
||||
name="use-time-weighted-retriever"
|
||||
checked={
|
||||
activeThread?.assistants[0].tools[0]
|
||||
.useTimeWeightedRetriever || false
|
||||
}
|
||||
onChange={(e) =>
|
||||
onTimeWeightedRetrieverSwitchUpdate(e.target.checked)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<AssistantSetting
|
||||
componentData={componentDataAssistantSetting}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
<AssistantSetting componentData={componentDataAssistantSetting} />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</Fragment>
|
||||
)
|
||||
}
|
||||
|
||||
@ -38,6 +38,7 @@ import PromptTemplateSetting from './PromptTemplateSetting'
|
||||
import Tools from './Tools'
|
||||
|
||||
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||
import {
|
||||
activeThreadAtom,
|
||||
@ -53,6 +54,7 @@ const ENGINE_SETTINGS = 'Engine Settings'
|
||||
|
||||
const ThreadRightPanel = () => {
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
|
||||
const selectedModel = useAtomValue(selectedModelAtom)
|
||||
const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom(
|
||||
@ -154,18 +156,18 @@ const ThreadRightPanel = () => {
|
||||
|
||||
const onAssistantInstructionChanged = useCallback(
|
||||
(e: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
if (activeThread)
|
||||
if (activeThread && activeAssistant)
|
||||
updateThreadMetadata({
|
||||
...activeThread,
|
||||
assistants: [
|
||||
{
|
||||
...activeThread.assistants[0],
|
||||
...activeAssistant,
|
||||
instructions: e.target.value || '',
|
||||
},
|
||||
],
|
||||
})
|
||||
},
|
||||
[activeThread, updateThreadMetadata]
|
||||
[activeAssistant, activeThread, updateThreadMetadata]
|
||||
)
|
||||
|
||||
const resetModel = useDebouncedCallback(() => {
|
||||
@ -174,7 +176,7 @@ const ThreadRightPanel = () => {
|
||||
|
||||
const onValueChanged = useCallback(
|
||||
(key: string, value: string | number | boolean | string[]) => {
|
||||
if (!activeThread) {
|
||||
if (!activeThread || !activeAssistant) {
|
||||
return
|
||||
}
|
||||
|
||||
@ -186,32 +188,38 @@ const ThreadRightPanel = () => {
|
||||
})
|
||||
|
||||
if (
|
||||
activeThread.assistants[0].model.parameters?.max_tokens &&
|
||||
activeThread.assistants[0].model.settings?.ctx_len
|
||||
activeAssistant.model.parameters?.max_tokens &&
|
||||
activeAssistant.model.settings?.ctx_len
|
||||
) {
|
||||
if (
|
||||
key === 'max_tokens' &&
|
||||
Number(value) > activeThread.assistants[0].model.settings.ctx_len
|
||||
Number(value) > activeAssistant.model.settings.ctx_len
|
||||
) {
|
||||
updateModelParameter(activeThread, {
|
||||
params: {
|
||||
max_tokens: activeThread.assistants[0].model.settings.ctx_len,
|
||||
max_tokens: activeAssistant.model.settings.ctx_len,
|
||||
},
|
||||
})
|
||||
}
|
||||
if (
|
||||
key === 'ctx_len' &&
|
||||
Number(value) < activeThread.assistants[0].model.parameters.max_tokens
|
||||
Number(value) < activeAssistant.model.parameters.max_tokens
|
||||
) {
|
||||
updateModelParameter(activeThread, {
|
||||
params: {
|
||||
max_tokens: activeThread.assistants[0].model.settings.ctx_len,
|
||||
max_tokens: activeAssistant.model.settings.ctx_len,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
},
|
||||
[activeThread, resetModel, setEngineParamsUpdate, updateModelParameter]
|
||||
[
|
||||
activeAssistant,
|
||||
activeThread,
|
||||
resetModel,
|
||||
setEngineParamsUpdate,
|
||||
updateModelParameter,
|
||||
]
|
||||
)
|
||||
|
||||
if (!activeThread) {
|
||||
@ -250,7 +258,7 @@ const ThreadRightPanel = () => {
|
||||
<TextArea
|
||||
id="assistant-instructions"
|
||||
placeholder="Eg. You are a helpful assistant."
|
||||
value={activeThread?.assistants[0].instructions ?? ''}
|
||||
value={activeAssistant?.instructions ?? ''}
|
||||
autoResize
|
||||
onChange={onAssistantInstructionChanged}
|
||||
/>
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user