Merge pull request #4249 from janhq/feat/threads-messages-requests-to-backend
feat: reroute threads and messages requests to cortex.cpp backend
This commit is contained in:
commit
72c9a981ae
@ -1,4 +1,10 @@
|
|||||||
import { Thread, ThreadInterface, ThreadMessage, MessageInterface } from '../../types'
|
import {
|
||||||
|
Thread,
|
||||||
|
ThreadInterface,
|
||||||
|
ThreadMessage,
|
||||||
|
MessageInterface,
|
||||||
|
ThreadAssistantInfo,
|
||||||
|
} from '../../types'
|
||||||
import { BaseExtension, ExtensionTypeEnum } from '../extension'
|
import { BaseExtension, ExtensionTypeEnum } from '../extension'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -17,10 +23,21 @@ export abstract class ConversationalExtension
|
|||||||
return ExtensionTypeEnum.Conversational
|
return ExtensionTypeEnum.Conversational
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract getThreads(): Promise<Thread[]>
|
abstract listThreads(): Promise<Thread[]>
|
||||||
abstract saveThread(thread: Thread): Promise<void>
|
abstract createThread(thread: Partial<Thread>): Promise<Thread>
|
||||||
|
abstract modifyThread(thread: Thread): Promise<void>
|
||||||
abstract deleteThread(threadId: string): Promise<void>
|
abstract deleteThread(threadId: string): Promise<void>
|
||||||
abstract addNewMessage(message: ThreadMessage): Promise<void>
|
abstract createMessage(message: Partial<ThreadMessage>): Promise<ThreadMessage>
|
||||||
abstract writeMessages(threadId: string, messages: ThreadMessage[]): Promise<void>
|
abstract deleteMessage(threadId: string, messageId: string): Promise<void>
|
||||||
abstract getAllMessages(threadId: string): Promise<ThreadMessage[]>
|
abstract listMessages(threadId: string): Promise<ThreadMessage[]>
|
||||||
|
abstract getThreadAssistant(threadId: string): Promise<ThreadAssistantInfo>
|
||||||
|
abstract createThreadAssistant(
|
||||||
|
threadId: string,
|
||||||
|
assistant: ThreadAssistantInfo
|
||||||
|
): Promise<ThreadAssistantInfo>
|
||||||
|
abstract modifyThreadAssistant(
|
||||||
|
threadId: string,
|
||||||
|
assistant: ThreadAssistantInfo
|
||||||
|
): Promise<ThreadAssistantInfo>
|
||||||
|
abstract modifyMessage(message: ThreadMessage): Promise<ThreadMessage>
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,7 +2,6 @@ import { events } from '../../events'
|
|||||||
import { BaseExtension } from '../../extension'
|
import { BaseExtension } from '../../extension'
|
||||||
import { MessageRequest, Model, ModelEvent } from '../../../types'
|
import { MessageRequest, Model, ModelEvent } from '../../../types'
|
||||||
import { EngineManager } from './EngineManager'
|
import { EngineManager } from './EngineManager'
|
||||||
import { ModelManager } from '../../models/manager'
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base AIEngine
|
* Base AIEngine
|
||||||
|
|||||||
@ -6,7 +6,6 @@ import {
|
|||||||
mkdirSync,
|
mkdirSync,
|
||||||
appendFileSync,
|
appendFileSync,
|
||||||
createWriteStream,
|
createWriteStream,
|
||||||
rmdirSync,
|
|
||||||
} from 'fs'
|
} from 'fs'
|
||||||
import { JanApiRouteConfiguration, RouteConfiguration } from './configuration'
|
import { JanApiRouteConfiguration, RouteConfiguration } from './configuration'
|
||||||
import { join } from 'path'
|
import { join } from 'path'
|
||||||
@ -126,7 +125,7 @@ export const createThread = async (thread: any) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const threadId = generateThreadId(thread.assistants[0].assistant_id)
|
const threadId = generateThreadId(thread.assistants[0]?.assistant_id)
|
||||||
try {
|
try {
|
||||||
const updatedThread = {
|
const updatedThread = {
|
||||||
...thread,
|
...thread,
|
||||||
@ -280,7 +279,7 @@ export const models = async (request: any, reply: any) => {
|
|||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await fetch(`${CORTEX_API_URL}/models${request.url.split('/models')[1] ?? ""}`, {
|
const response = await fetch(`${CORTEX_API_URL}/models${request.url.split('/models')[1] ?? ''}`, {
|
||||||
method: request.method,
|
method: request.method,
|
||||||
headers: headers,
|
headers: headers,
|
||||||
body: JSON.stringify(request.body),
|
body: JSON.stringify(request.body),
|
||||||
|
|||||||
@ -36,3 +36,10 @@ export type Assistant = {
|
|||||||
/** Represents the metadata of the object. */
|
/** Represents the metadata of the object. */
|
||||||
metadata?: Record<string, unknown>
|
metadata?: Record<string, unknown>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface CodeInterpreterTool {
|
||||||
|
/**
|
||||||
|
* The type of tool being defined: `code_interpreter`
|
||||||
|
*/
|
||||||
|
type: 'code_interpreter'
|
||||||
|
}
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import { CodeInterpreterTool } from '../assistant'
|
||||||
import { ChatCompletionMessage, ChatCompletionRole } from '../inference'
|
import { ChatCompletionMessage, ChatCompletionRole } from '../inference'
|
||||||
import { ModelInfo } from '../model'
|
import { ModelInfo } from '../model'
|
||||||
import { Thread } from '../thread'
|
import { Thread } from '../thread'
|
||||||
@ -15,6 +16,10 @@ export type ThreadMessage = {
|
|||||||
thread_id: string
|
thread_id: string
|
||||||
/** The assistant id of this thread. **/
|
/** The assistant id of this thread. **/
|
||||||
assistant_id?: string
|
assistant_id?: string
|
||||||
|
/**
|
||||||
|
* A list of files attached to the message, and the tools they were added to.
|
||||||
|
*/
|
||||||
|
attachments?: Array<Attachment> | null
|
||||||
/** The role of the author of this message. **/
|
/** The role of the author of this message. **/
|
||||||
role: ChatCompletionRole
|
role: ChatCompletionRole
|
||||||
/** The content of this message. **/
|
/** The content of this message. **/
|
||||||
@ -52,6 +57,11 @@ export type MessageRequest = {
|
|||||||
*/
|
*/
|
||||||
assistantId?: string
|
assistantId?: string
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A list of files attached to the message, and the tools they were added to.
|
||||||
|
*/
|
||||||
|
attachments: Array<Attachment> | null
|
||||||
|
|
||||||
/** Messages for constructing a chat completion request **/
|
/** Messages for constructing a chat completion request **/
|
||||||
messages?: ChatCompletionMessage[]
|
messages?: ChatCompletionMessage[]
|
||||||
|
|
||||||
@ -97,8 +107,7 @@ export enum ErrorCode {
|
|||||||
*/
|
*/
|
||||||
export enum ContentType {
|
export enum ContentType {
|
||||||
Text = 'text',
|
Text = 'text',
|
||||||
Image = 'image',
|
Image = 'image_url',
|
||||||
Pdf = 'pdf',
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -108,8 +117,15 @@ export enum ContentType {
|
|||||||
export type ContentValue = {
|
export type ContentValue = {
|
||||||
value: string
|
value: string
|
||||||
annotations: string[]
|
annotations: string[]
|
||||||
name?: string
|
}
|
||||||
size?: number
|
|
||||||
|
/**
|
||||||
|
* The `ImageContentValue` type defines the shape of a content value object of image type
|
||||||
|
* @data_transfer_object
|
||||||
|
*/
|
||||||
|
export type ImageContentValue = {
|
||||||
|
detail?: string
|
||||||
|
url?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -118,5 +134,37 @@ export type ContentValue = {
|
|||||||
*/
|
*/
|
||||||
export type ThreadContent = {
|
export type ThreadContent = {
|
||||||
type: ContentType
|
type: ContentType
|
||||||
text: ContentValue
|
text?: ContentValue
|
||||||
|
image_url?: ImageContentValue
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface Attachment {
|
||||||
|
/**
|
||||||
|
* The ID of the file to attach to the message.
|
||||||
|
*/
|
||||||
|
file_id?: string
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The tools to add this file to.
|
||||||
|
*/
|
||||||
|
tools?: Array<CodeInterpreterTool | Attachment.AssistantToolsFileSearchTypeOnly>
|
||||||
|
}
|
||||||
|
|
||||||
|
export namespace Attachment {
|
||||||
|
export interface AssistantToolsFileSearchTypeOnly {
|
||||||
|
/**
|
||||||
|
* The type of tool being defined: `file_search`
|
||||||
|
*/
|
||||||
|
type: 'file_search'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* On an incomplete message, details about why the message is incomplete.
|
||||||
|
*/
|
||||||
|
export interface IncompleteDetails {
|
||||||
|
/**
|
||||||
|
* The reason the message is incomplete.
|
||||||
|
*/
|
||||||
|
reason: 'content_filter' | 'max_tokens' | 'run_cancelled' | 'run_expired' | 'run_failed'
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,20 +11,20 @@ export interface MessageInterface {
|
|||||||
* @param {ThreadMessage} message - The message to be added.
|
* @param {ThreadMessage} message - The message to be added.
|
||||||
* @returns {Promise<void>} A promise that resolves when the message has been added.
|
* @returns {Promise<void>} A promise that resolves when the message has been added.
|
||||||
*/
|
*/
|
||||||
addNewMessage(message: ThreadMessage): Promise<void>
|
createMessage(message: ThreadMessage): Promise<ThreadMessage>
|
||||||
|
|
||||||
/**
|
|
||||||
* Writes an array of messages to a specific thread.
|
|
||||||
* @param {string} threadId - The ID of the thread to write the messages to.
|
|
||||||
* @param {ThreadMessage[]} messages - The array of messages to be written.
|
|
||||||
* @returns {Promise<void>} A promise that resolves when the messages have been written.
|
|
||||||
*/
|
|
||||||
writeMessages(threadId: string, messages: ThreadMessage[]): Promise<void>
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Retrieves all messages from a specific thread.
|
* Retrieves all messages from a specific thread.
|
||||||
* @param {string} threadId - The ID of the thread to retrieve the messages from.
|
* @param {string} threadId - The ID of the thread to retrieve the messages from.
|
||||||
* @returns {Promise<ThreadMessage[]>} A promise that resolves to an array of messages from the thread.
|
* @returns {Promise<ThreadMessage[]>} A promise that resolves to an array of messages from the thread.
|
||||||
*/
|
*/
|
||||||
getAllMessages(threadId: string): Promise<ThreadMessage[]>
|
listMessages(threadId: string): Promise<ThreadMessage[]>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Deletes a specific message from a thread.
|
||||||
|
* @param {string} threadId - The ID of the thread from which the message will be deleted.
|
||||||
|
* @param {string} messageId - The ID of the message to be deleted.
|
||||||
|
* @returns {Promise<void>} A promise that resolves when the message has been successfully deleted.
|
||||||
|
*/
|
||||||
|
deleteMessage(threadId: string, messageId: string): Promise<void>
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,15 +11,23 @@ export interface ThreadInterface {
|
|||||||
* @abstract
|
* @abstract
|
||||||
* @returns {Promise<Thread[]>} A promise that resolves to an array of threads.
|
* @returns {Promise<Thread[]>} A promise that resolves to an array of threads.
|
||||||
*/
|
*/
|
||||||
getThreads(): Promise<Thread[]>
|
listThreads(): Promise<Thread[]>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Saves a thread.
|
* Create a thread.
|
||||||
* @abstract
|
* @abstract
|
||||||
* @param {Thread} thread - The thread to save.
|
* @param {Thread} thread - The thread to save.
|
||||||
* @returns {Promise<void>} A promise that resolves when the thread is saved.
|
* @returns {Promise<void>} A promise that resolves when the thread is saved.
|
||||||
*/
|
*/
|
||||||
saveThread(thread: Thread): Promise<void>
|
createThread(thread: Thread): Promise<Thread>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* modify a thread.
|
||||||
|
* @abstract
|
||||||
|
* @param {Thread} thread - The thread to save.
|
||||||
|
* @returns {Promise<void>} A promise that resolves when the thread is saved.
|
||||||
|
*/
|
||||||
|
modifyThread(thread: Thread): Promise<void>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Deletes a thread.
|
* Deletes a thread.
|
||||||
|
|||||||
@ -108,7 +108,7 @@ export const test = base.extend<
|
|||||||
})
|
})
|
||||||
|
|
||||||
test.beforeAll(async () => {
|
test.beforeAll(async () => {
|
||||||
await rmSync(path.join(__dirname, '../../test-data'), {
|
rmSync(path.join(__dirname, '../../test-data'), {
|
||||||
recursive: true,
|
recursive: true,
|
||||||
force: true,
|
force: true,
|
||||||
})
|
})
|
||||||
@ -122,6 +122,5 @@ test.beforeAll(async () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
test.afterAll(async () => {
|
test.afterAll(async () => {
|
||||||
// temporally disabling this due to the config for parallel testing WIP
|
|
||||||
// teardownElectron()
|
// teardownElectron()
|
||||||
})
|
})
|
||||||
|
|||||||
@ -2,11 +2,8 @@ import { expect } from '@playwright/test'
|
|||||||
import { page, test, TIMEOUT } from '../config/fixtures'
|
import { page, test, TIMEOUT } from '../config/fixtures'
|
||||||
|
|
||||||
test('renders left navigation panel', async () => {
|
test('renders left navigation panel', async () => {
|
||||||
const settingsBtn = await page
|
const threadBtn = page.getByTestId('Thread').first()
|
||||||
.getByTestId('Thread')
|
await expect(threadBtn).toBeVisible({ timeout: TIMEOUT })
|
||||||
.first()
|
|
||||||
.isEnabled({ timeout: TIMEOUT })
|
|
||||||
expect([settingsBtn].filter((e) => !e).length).toBe(0)
|
|
||||||
// Chat section should be there
|
// Chat section should be there
|
||||||
await page.getByTestId('Local API Server').first().click({
|
await page.getByTestId('Local API Server').first().click({
|
||||||
timeout: TIMEOUT,
|
timeout: TIMEOUT,
|
||||||
|
|||||||
@ -141,7 +141,7 @@ export default class JanAssistantExtension extends AssistantExtension {
|
|||||||
top_k: 2,
|
top_k: 2,
|
||||||
chunk_size: 1024,
|
chunk_size: 1024,
|
||||||
chunk_overlap: 64,
|
chunk_overlap: 64,
|
||||||
retrieval_template: `Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
retrieval_template: `Use the following pieces of context to answer the question at the end.
|
||||||
----------------
|
----------------
|
||||||
CONTEXT: {CONTEXT}
|
CONTEXT: {CONTEXT}
|
||||||
----------------
|
----------------
|
||||||
|
|||||||
@ -9,13 +9,14 @@ export function toolRetrievalUpdateTextSplitter(
|
|||||||
retrieval.updateTextSplitter(chunkSize, chunkOverlap)
|
retrieval.updateTextSplitter(chunkSize, chunkOverlap)
|
||||||
}
|
}
|
||||||
export async function toolRetrievalIngestNewDocument(
|
export async function toolRetrievalIngestNewDocument(
|
||||||
|
thread: string,
|
||||||
file: string,
|
file: string,
|
||||||
model: string,
|
model: string,
|
||||||
engine: string,
|
engine: string,
|
||||||
useTimeWeighted: boolean
|
useTimeWeighted: boolean
|
||||||
) {
|
) {
|
||||||
const filePath = path.join(getJanDataFolderPath(), normalizeFilePath(file))
|
const threadPath = path.join(getJanDataFolderPath(), 'threads', thread)
|
||||||
const threadPath = path.dirname(filePath.replace('files', ''))
|
const filePath = path.join(getJanDataFolderPath(), 'files', file)
|
||||||
retrieval.updateEmbeddingEngine(model, engine)
|
retrieval.updateEmbeddingEngine(model, engine)
|
||||||
return retrieval
|
return retrieval
|
||||||
.ingestAgentKnowledge(filePath, `${threadPath}/memory`, useTimeWeighted)
|
.ingestAgentKnowledge(filePath, `${threadPath}/memory`, useTimeWeighted)
|
||||||
|
|||||||
@ -35,6 +35,7 @@ export class RetrievalTool extends InferenceTool {
|
|||||||
await executeOnMain(
|
await executeOnMain(
|
||||||
NODE,
|
NODE,
|
||||||
'toolRetrievalIngestNewDocument',
|
'toolRetrievalIngestNewDocument',
|
||||||
|
data.thread?.id,
|
||||||
docFile,
|
docFile,
|
||||||
data.model?.id,
|
data.model?.id,
|
||||||
data.model?.engine,
|
data.model?.engine,
|
||||||
|
|||||||
@ -18,12 +18,14 @@
|
|||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"cpx": "^1.5.0",
|
"cpx": "^1.5.0",
|
||||||
"rimraf": "^3.0.2",
|
"rimraf": "^3.0.2",
|
||||||
|
"ts-loader": "^9.5.0",
|
||||||
"webpack": "^5.88.2",
|
"webpack": "^5.88.2",
|
||||||
"webpack-cli": "^5.1.4",
|
"webpack-cli": "^5.1.4"
|
||||||
"ts-loader": "^9.5.0"
|
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@janhq/core": "file:../../core"
|
"@janhq/core": "file:../../core",
|
||||||
|
"ky": "^1.7.2",
|
||||||
|
"p-queue": "^8.0.1"
|
||||||
},
|
},
|
||||||
"engines": {
|
"engines": {
|
||||||
"node": ">=18.0.0"
|
"node": ">=18.0.0"
|
||||||
|
|||||||
14
extensions/conversational-extension/src/@types/global.d.ts
vendored
Normal file
14
extensions/conversational-extension/src/@types/global.d.ts
vendored
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
export {}
|
||||||
|
declare global {
|
||||||
|
declare const API_URL: string
|
||||||
|
declare const SOCKET_URL: string
|
||||||
|
|
||||||
|
interface Core {
|
||||||
|
api: APIFunctions
|
||||||
|
events: EventEmitter
|
||||||
|
}
|
||||||
|
interface Window {
|
||||||
|
core?: Core | undefined
|
||||||
|
electronAPI?: any | undefined
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,408 +0,0 @@
|
|||||||
/**
|
|
||||||
* @jest-environment jsdom
|
|
||||||
*/
|
|
||||||
jest.mock('@janhq/core', () => ({
|
|
||||||
...jest.requireActual('@janhq/core/node'),
|
|
||||||
fs: {
|
|
||||||
existsSync: jest.fn(),
|
|
||||||
mkdir: jest.fn(),
|
|
||||||
writeFileSync: jest.fn(),
|
|
||||||
readdirSync: jest.fn(),
|
|
||||||
readFileSync: jest.fn(),
|
|
||||||
appendFileSync: jest.fn(),
|
|
||||||
rm: jest.fn(),
|
|
||||||
writeBlob: jest.fn(),
|
|
||||||
joinPath: jest.fn(),
|
|
||||||
fileStat: jest.fn(),
|
|
||||||
},
|
|
||||||
joinPath: jest.fn(),
|
|
||||||
ConversationalExtension: jest.fn(),
|
|
||||||
}))
|
|
||||||
|
|
||||||
import { fs } from '@janhq/core'
|
|
||||||
|
|
||||||
import JSONConversationalExtension from '.'
|
|
||||||
|
|
||||||
describe('JSONConversationalExtension Tests', () => {
|
|
||||||
let extension: JSONConversationalExtension
|
|
||||||
|
|
||||||
beforeEach(() => {
|
|
||||||
// @ts-ignore
|
|
||||||
extension = new JSONConversationalExtension()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should create thread folder on load if it does not exist', async () => {
|
|
||||||
// @ts-ignore
|
|
||||||
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
|
|
||||||
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
|
|
||||||
|
|
||||||
await extension.onLoad()
|
|
||||||
|
|
||||||
expect(mkdirSpy).toHaveBeenCalledWith('file://threads')
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should log message on unload', () => {
|
|
||||||
const consoleSpy = jest.spyOn(console, 'debug').mockImplementation()
|
|
||||||
|
|
||||||
extension.onUnload()
|
|
||||||
|
|
||||||
expect(consoleSpy).toHaveBeenCalledWith(
|
|
||||||
'JSONConversationalExtension unloaded'
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should return sorted threads', async () => {
|
|
||||||
jest
|
|
||||||
.spyOn(extension, 'getValidThreadDirs')
|
|
||||||
.mockResolvedValue(['dir1', 'dir2'])
|
|
||||||
jest
|
|
||||||
.spyOn(extension, 'readThread')
|
|
||||||
.mockResolvedValueOnce({ updated: '2023-01-01' })
|
|
||||||
.mockResolvedValueOnce({ updated: '2023-01-02' })
|
|
||||||
|
|
||||||
const threads = await extension.getThreads()
|
|
||||||
|
|
||||||
expect(threads).toEqual([
|
|
||||||
{ updated: '2023-01-02' },
|
|
||||||
{ updated: '2023-01-01' },
|
|
||||||
])
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should ignore broken threads', async () => {
|
|
||||||
jest
|
|
||||||
.spyOn(extension, 'getValidThreadDirs')
|
|
||||||
.mockResolvedValue(['dir1', 'dir2'])
|
|
||||||
jest
|
|
||||||
.spyOn(extension, 'readThread')
|
|
||||||
.mockResolvedValueOnce(JSON.stringify({ updated: '2023-01-01' }))
|
|
||||||
.mockResolvedValueOnce('this_is_an_invalid_json_content')
|
|
||||||
|
|
||||||
const threads = await extension.getThreads()
|
|
||||||
|
|
||||||
expect(threads).toEqual([{ updated: '2023-01-01' }])
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should save thread', async () => {
|
|
||||||
// @ts-ignore
|
|
||||||
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
|
|
||||||
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
|
|
||||||
const writeFileSyncSpy = jest
|
|
||||||
.spyOn(fs, 'writeFileSync')
|
|
||||||
.mockResolvedValue({})
|
|
||||||
|
|
||||||
const thread = { id: '1', updated: '2023-01-01' } as any
|
|
||||||
await extension.saveThread(thread)
|
|
||||||
|
|
||||||
expect(mkdirSpy).toHaveBeenCalled()
|
|
||||||
expect(writeFileSyncSpy).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should delete thread', async () => {
|
|
||||||
const rmSpy = jest.spyOn(fs, 'rm').mockResolvedValue({})
|
|
||||||
|
|
||||||
await extension.deleteThread('1')
|
|
||||||
|
|
||||||
expect(rmSpy).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should add new message', async () => {
|
|
||||||
// @ts-ignore
|
|
||||||
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
|
|
||||||
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
|
|
||||||
const appendFileSyncSpy = jest
|
|
||||||
.spyOn(fs, 'appendFileSync')
|
|
||||||
.mockResolvedValue({})
|
|
||||||
|
|
||||||
const message = {
|
|
||||||
thread_id: '1',
|
|
||||||
content: [{ type: 'text', text: { annotations: [] } }],
|
|
||||||
} as any
|
|
||||||
await extension.addNewMessage(message)
|
|
||||||
|
|
||||||
expect(mkdirSpy).toHaveBeenCalled()
|
|
||||||
expect(appendFileSyncSpy).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should store image', async () => {
|
|
||||||
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
|
|
||||||
|
|
||||||
await extension.storeImage(
|
|
||||||
'data:image/png;base64,abcd',
|
|
||||||
'path/to/image.png'
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(writeBlobSpy).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should store file', async () => {
|
|
||||||
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
|
|
||||||
|
|
||||||
await extension.storeFile(
|
|
||||||
'data:application/pdf;base64,abcd',
|
|
||||||
'path/to/file.pdf'
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(writeBlobSpy).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should write messages', async () => {
|
|
||||||
// @ts-ignore
|
|
||||||
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
|
|
||||||
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
|
|
||||||
const writeFileSyncSpy = jest
|
|
||||||
.spyOn(fs, 'writeFileSync')
|
|
||||||
.mockResolvedValue({})
|
|
||||||
|
|
||||||
const messages = [{ id: '1', thread_id: '1', content: [] }] as any
|
|
||||||
await extension.writeMessages('1', messages)
|
|
||||||
|
|
||||||
expect(mkdirSpy).toHaveBeenCalled()
|
|
||||||
expect(writeFileSyncSpy).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should get all messages on string response', async () => {
|
|
||||||
jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl'])
|
|
||||||
jest.spyOn(fs, 'readFileSync').mockResolvedValue('{"id":"1"}\n{"id":"2"}\n')
|
|
||||||
|
|
||||||
const messages = await extension.getAllMessages('1')
|
|
||||||
|
|
||||||
expect(messages).toEqual([{ id: '1' }, { id: '2' }])
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should get all messages on object response', async () => {
|
|
||||||
jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl'])
|
|
||||||
jest.spyOn(fs, 'readFileSync').mockResolvedValue({ id: 1 })
|
|
||||||
|
|
||||||
const messages = await extension.getAllMessages('1')
|
|
||||||
|
|
||||||
expect(messages).toEqual([{ id: 1 }])
|
|
||||||
})
|
|
||||||
|
|
||||||
it('get all messages return empty on error', async () => {
|
|
||||||
jest.spyOn(fs, 'readdirSync').mockRejectedValue(['messages.jsonl'])
|
|
||||||
|
|
||||||
const messages = await extension.getAllMessages('1')
|
|
||||||
|
|
||||||
expect(messages).toEqual([])
|
|
||||||
})
|
|
||||||
|
|
||||||
it('return empty messages on no messages file', async () => {
|
|
||||||
jest.spyOn(fs, 'readdirSync').mockResolvedValue([])
|
|
||||||
|
|
||||||
const messages = await extension.getAllMessages('1')
|
|
||||||
|
|
||||||
expect(messages).toEqual([])
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should ignore error message', async () => {
|
|
||||||
jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl'])
|
|
||||||
jest
|
|
||||||
.spyOn(fs, 'readFileSync')
|
|
||||||
.mockResolvedValue('{"id":"1"}\nyolo\n{"id":"2"}\n')
|
|
||||||
|
|
||||||
const messages = await extension.getAllMessages('1')
|
|
||||||
|
|
||||||
expect(messages).toEqual([{ id: '1' }, { id: '2' }])
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should create thread folder on load if it does not exist', async () => {
|
|
||||||
// @ts-ignore
|
|
||||||
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
|
|
||||||
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
|
|
||||||
|
|
||||||
await extension.onLoad()
|
|
||||||
|
|
||||||
expect(mkdirSpy).toHaveBeenCalledWith('file://threads')
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should log message on unload', () => {
|
|
||||||
const consoleSpy = jest.spyOn(console, 'debug').mockImplementation()
|
|
||||||
|
|
||||||
extension.onUnload()
|
|
||||||
|
|
||||||
expect(consoleSpy).toHaveBeenCalledWith(
|
|
||||||
'JSONConversationalExtension unloaded'
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should return sorted threads', async () => {
|
|
||||||
jest
|
|
||||||
.spyOn(extension, 'getValidThreadDirs')
|
|
||||||
.mockResolvedValue(['dir1', 'dir2'])
|
|
||||||
jest
|
|
||||||
.spyOn(extension, 'readThread')
|
|
||||||
.mockResolvedValueOnce({ updated: '2023-01-01' })
|
|
||||||
.mockResolvedValueOnce({ updated: '2023-01-02' })
|
|
||||||
|
|
||||||
const threads = await extension.getThreads()
|
|
||||||
|
|
||||||
expect(threads).toEqual([
|
|
||||||
{ updated: '2023-01-02' },
|
|
||||||
{ updated: '2023-01-01' },
|
|
||||||
])
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should ignore broken threads', async () => {
|
|
||||||
jest
|
|
||||||
.spyOn(extension, 'getValidThreadDirs')
|
|
||||||
.mockResolvedValue(['dir1', 'dir2'])
|
|
||||||
jest
|
|
||||||
.spyOn(extension, 'readThread')
|
|
||||||
.mockResolvedValueOnce(JSON.stringify({ updated: '2023-01-01' }))
|
|
||||||
.mockResolvedValueOnce('this_is_an_invalid_json_content')
|
|
||||||
|
|
||||||
const threads = await extension.getThreads()
|
|
||||||
|
|
||||||
expect(threads).toEqual([{ updated: '2023-01-01' }])
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should save thread', async () => {
|
|
||||||
// @ts-ignore
|
|
||||||
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
|
|
||||||
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
|
|
||||||
const writeFileSyncSpy = jest
|
|
||||||
.spyOn(fs, 'writeFileSync')
|
|
||||||
.mockResolvedValue({})
|
|
||||||
|
|
||||||
const thread = { id: '1', updated: '2023-01-01' } as any
|
|
||||||
await extension.saveThread(thread)
|
|
||||||
|
|
||||||
expect(mkdirSpy).toHaveBeenCalled()
|
|
||||||
expect(writeFileSyncSpy).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should delete thread', async () => {
|
|
||||||
const rmSpy = jest.spyOn(fs, 'rm').mockResolvedValue({})
|
|
||||||
|
|
||||||
await extension.deleteThread('1')
|
|
||||||
|
|
||||||
expect(rmSpy).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should add new message', async () => {
|
|
||||||
// @ts-ignore
|
|
||||||
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
|
|
||||||
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
|
|
||||||
const appendFileSyncSpy = jest
|
|
||||||
.spyOn(fs, 'appendFileSync')
|
|
||||||
.mockResolvedValue({})
|
|
||||||
|
|
||||||
const message = {
|
|
||||||
thread_id: '1',
|
|
||||||
content: [{ type: 'text', text: { annotations: [] } }],
|
|
||||||
} as any
|
|
||||||
await extension.addNewMessage(message)
|
|
||||||
|
|
||||||
expect(mkdirSpy).toHaveBeenCalled()
|
|
||||||
expect(appendFileSyncSpy).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should add new image message', async () => {
|
|
||||||
jest
|
|
||||||
.spyOn(fs, 'existsSync')
|
|
||||||
// @ts-ignore
|
|
||||||
.mockResolvedValueOnce(false)
|
|
||||||
// @ts-ignore
|
|
||||||
.mockResolvedValueOnce(false)
|
|
||||||
// @ts-ignore
|
|
||||||
.mockResolvedValueOnce(true)
|
|
||||||
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
|
|
||||||
const appendFileSyncSpy = jest
|
|
||||||
.spyOn(fs, 'appendFileSync')
|
|
||||||
.mockResolvedValue({})
|
|
||||||
jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
|
|
||||||
|
|
||||||
const message = {
|
|
||||||
thread_id: '1',
|
|
||||||
content: [
|
|
||||||
{ type: 'image', text: { annotations: ['data:image;base64,hehe'] } },
|
|
||||||
],
|
|
||||||
} as any
|
|
||||||
await extension.addNewMessage(message)
|
|
||||||
|
|
||||||
expect(mkdirSpy).toHaveBeenCalled()
|
|
||||||
expect(appendFileSyncSpy).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should add new pdf message', async () => {
|
|
||||||
jest
|
|
||||||
.spyOn(fs, 'existsSync')
|
|
||||||
// @ts-ignore
|
|
||||||
.mockResolvedValueOnce(false)
|
|
||||||
// @ts-ignore
|
|
||||||
.mockResolvedValueOnce(false)
|
|
||||||
// @ts-ignore
|
|
||||||
.mockResolvedValueOnce(true)
|
|
||||||
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
|
|
||||||
const appendFileSyncSpy = jest
|
|
||||||
.spyOn(fs, 'appendFileSync')
|
|
||||||
.mockResolvedValue({})
|
|
||||||
jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
|
|
||||||
|
|
||||||
const message = {
|
|
||||||
thread_id: '1',
|
|
||||||
content: [
|
|
||||||
{ type: 'pdf', text: { annotations: ['data:pdf;base64,hehe'] } },
|
|
||||||
],
|
|
||||||
} as any
|
|
||||||
await extension.addNewMessage(message)
|
|
||||||
|
|
||||||
expect(mkdirSpy).toHaveBeenCalled()
|
|
||||||
expect(appendFileSyncSpy).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should store image', async () => {
|
|
||||||
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
|
|
||||||
|
|
||||||
await extension.storeImage(
|
|
||||||
'data:image/png;base64,abcd',
|
|
||||||
'path/to/image.png'
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(writeBlobSpy).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should store file', async () => {
|
|
||||||
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
|
|
||||||
|
|
||||||
await extension.storeFile(
|
|
||||||
'data:application/pdf;base64,abcd',
|
|
||||||
'path/to/file.pdf'
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(writeBlobSpy).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('test readThread', () => {
|
|
||||||
let extension: JSONConversationalExtension
|
|
||||||
|
|
||||||
beforeEach(() => {
|
|
||||||
// @ts-ignore
|
|
||||||
extension = new JSONConversationalExtension()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should read thread', async () => {
|
|
||||||
jest
|
|
||||||
.spyOn(fs, 'readFileSync')
|
|
||||||
.mockResolvedValue(JSON.stringify({ id: '1' }))
|
|
||||||
const thread = await extension.readThread('1')
|
|
||||||
expect(thread).toEqual(`{"id":"1"}`)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('getValidThreadDirs should return valid thread directories', async () => {
|
|
||||||
jest
|
|
||||||
.spyOn(fs, 'readdirSync')
|
|
||||||
.mockResolvedValueOnce(['1', '2', '3'])
|
|
||||||
.mockResolvedValueOnce(['thread.json'])
|
|
||||||
.mockResolvedValueOnce(['thread.json'])
|
|
||||||
.mockResolvedValueOnce([])
|
|
||||||
// @ts-ignore
|
|
||||||
jest.spyOn(fs, 'existsSync').mockResolvedValue(true)
|
|
||||||
jest.spyOn(fs, 'fileStat').mockResolvedValue({
|
|
||||||
isDirectory: true,
|
|
||||||
} as any)
|
|
||||||
const validThreadDirs = await extension.getValidThreadDirs()
|
|
||||||
expect(validThreadDirs).toEqual(['1', '2'])
|
|
||||||
})
|
|
||||||
})
|
|
||||||
@ -1,90 +1,71 @@
|
|||||||
import {
|
import {
|
||||||
fs,
|
|
||||||
joinPath,
|
|
||||||
ConversationalExtension,
|
ConversationalExtension,
|
||||||
Thread,
|
Thread,
|
||||||
|
ThreadAssistantInfo,
|
||||||
ThreadMessage,
|
ThreadMessage,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
import { safelyParseJSON } from './jsonUtil'
|
import ky from 'ky'
|
||||||
|
import PQueue from 'p-queue'
|
||||||
|
|
||||||
|
type ThreadList = {
|
||||||
|
data: Thread[]
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageList = {
|
||||||
|
data: ThreadMessage[]
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* JSONConversationalExtension is a ConversationalExtension implementation that provides
|
* JSONConversationalExtension is a ConversationalExtension implementation that provides
|
||||||
* functionality for managing threads.
|
* functionality for managing threads.
|
||||||
*/
|
*/
|
||||||
export default class JSONConversationalExtension extends ConversationalExtension {
|
export default class JSONConversationalExtension extends ConversationalExtension {
|
||||||
private static readonly _threadFolder = 'file://threads'
|
queue = new PQueue({ concurrency: 1 })
|
||||||
private static readonly _threadInfoFileName = 'thread.json'
|
|
||||||
private static readonly _threadMessagesFileName = 'messages.jsonl'
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Called when the extension is loaded.
|
* Called when the extension is loaded.
|
||||||
*/
|
*/
|
||||||
async onLoad() {
|
async onLoad() {
|
||||||
if (!(await fs.existsSync(JSONConversationalExtension._threadFolder))) {
|
this.queue.add(() => this.healthz())
|
||||||
await fs.mkdir(JSONConversationalExtension._threadFolder)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Called when the extension is unloaded.
|
* Called when the extension is unloaded.
|
||||||
*/
|
*/
|
||||||
onUnload() {
|
onUnload() {}
|
||||||
console.debug('JSONConversationalExtension unloaded')
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a Promise that resolves to an array of Conversation objects.
|
* Returns a Promise that resolves to an array of Conversation objects.
|
||||||
*/
|
*/
|
||||||
async getThreads(): Promise<Thread[]> {
|
async listThreads(): Promise<Thread[]> {
|
||||||
try {
|
return this.queue.add(() =>
|
||||||
const threadDirs = await this.getValidThreadDirs()
|
ky
|
||||||
|
.get(`${API_URL}/v1/threads`)
|
||||||
const promises = threadDirs.map((dirName) => this.readThread(dirName))
|
.json<ThreadList>()
|
||||||
const promiseResults = await Promise.allSettled(promises)
|
.then((e) => e.data)
|
||||||
const convos = promiseResults
|
) as Promise<Thread[]>
|
||||||
.map((result) => {
|
|
||||||
if (result.status === 'fulfilled') {
|
|
||||||
return typeof result.value === 'object'
|
|
||||||
? result.value
|
|
||||||
: safelyParseJSON(result.value)
|
|
||||||
}
|
|
||||||
return undefined
|
|
||||||
})
|
|
||||||
.filter((convo) => !!convo)
|
|
||||||
convos.sort(
|
|
||||||
(a, b) => new Date(b.updated).getTime() - new Date(a.updated).getTime()
|
|
||||||
)
|
|
||||||
|
|
||||||
return convos
|
|
||||||
} catch (error) {
|
|
||||||
console.error(error)
|
|
||||||
return []
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Saves a Thread object to a json file.
|
* Saves a Thread object to a json file.
|
||||||
* @param thread The Thread object to save.
|
* @param thread The Thread object to save.
|
||||||
*/
|
*/
|
||||||
async saveThread(thread: Thread): Promise<void> {
|
async createThread(thread: Thread): Promise<Thread> {
|
||||||
try {
|
return this.queue.add(() =>
|
||||||
const threadDirPath = await joinPath([
|
ky.post(`${API_URL}/v1/threads`, { json: thread }).json<Thread>()
|
||||||
JSONConversationalExtension._threadFolder,
|
) as Promise<Thread>
|
||||||
thread.id,
|
|
||||||
])
|
|
||||||
const threadJsonPath = await joinPath([
|
|
||||||
threadDirPath,
|
|
||||||
JSONConversationalExtension._threadInfoFileName,
|
|
||||||
])
|
|
||||||
if (!(await fs.existsSync(threadDirPath))) {
|
|
||||||
await fs.mkdir(threadDirPath)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
await fs.writeFileSync(threadJsonPath, JSON.stringify(thread, null, 2))
|
/**
|
||||||
} catch (err) {
|
* Saves a Thread object to a json file.
|
||||||
console.error(err)
|
* @param thread The Thread object to save.
|
||||||
Promise.reject(err)
|
*/
|
||||||
}
|
async modifyThread(thread: Thread): Promise<void> {
|
||||||
|
return this.queue
|
||||||
|
.add(() =>
|
||||||
|
ky.post(`${API_URL}/v1/threads/${thread.id}`, { json: thread })
|
||||||
|
)
|
||||||
|
.then()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -92,189 +73,126 @@ export default class JSONConversationalExtension extends ConversationalExtension
|
|||||||
* @param threadId The ID of the thread to delete.
|
* @param threadId The ID of the thread to delete.
|
||||||
*/
|
*/
|
||||||
async deleteThread(threadId: string): Promise<void> {
|
async deleteThread(threadId: string): Promise<void> {
|
||||||
const path = await joinPath([
|
return this.queue
|
||||||
JSONConversationalExtension._threadFolder,
|
.add(() => ky.delete(`${API_URL}/v1/threads/${threadId}`))
|
||||||
`${threadId}`,
|
.then()
|
||||||
])
|
|
||||||
try {
|
|
||||||
await fs.rm(path)
|
|
||||||
} catch (err) {
|
|
||||||
console.error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async addNewMessage(message: ThreadMessage): Promise<void> {
|
|
||||||
try {
|
|
||||||
const threadDirPath = await joinPath([
|
|
||||||
JSONConversationalExtension._threadFolder,
|
|
||||||
message.thread_id,
|
|
||||||
])
|
|
||||||
const threadMessagePath = await joinPath([
|
|
||||||
threadDirPath,
|
|
||||||
JSONConversationalExtension._threadMessagesFileName,
|
|
||||||
])
|
|
||||||
if (!(await fs.existsSync(threadDirPath))) await fs.mkdir(threadDirPath)
|
|
||||||
|
|
||||||
if (message.content[0]?.type === 'image') {
|
|
||||||
const filesPath = await joinPath([threadDirPath, 'files'])
|
|
||||||
if (!(await fs.existsSync(filesPath))) await fs.mkdir(filesPath)
|
|
||||||
|
|
||||||
const imagePath = await joinPath([filesPath, `${message.id}.png`])
|
|
||||||
const base64 = message.content[0].text.annotations[0]
|
|
||||||
await this.storeImage(base64, imagePath)
|
|
||||||
if ((await fs.existsSync(imagePath)) && message.content?.length) {
|
|
||||||
// Use file path instead of blob
|
|
||||||
message.content[0].text.annotations[0] = `threads/${message.thread_id}/files/${message.id}.png`
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (message.content[0]?.type === 'pdf') {
|
|
||||||
const filesPath = await joinPath([threadDirPath, 'files'])
|
|
||||||
if (!(await fs.existsSync(filesPath))) await fs.mkdir(filesPath)
|
|
||||||
|
|
||||||
const filePath = await joinPath([filesPath, `${message.id}.pdf`])
|
|
||||||
const blob = message.content[0].text.annotations[0]
|
|
||||||
await this.storeFile(blob, filePath)
|
|
||||||
|
|
||||||
if ((await fs.existsSync(filePath)) && message.content?.length) {
|
|
||||||
// Use file path instead of blob
|
|
||||||
message.content[0].text.annotations[0] = `threads/${message.thread_id}/files/${message.id}.pdf`
|
|
||||||
}
|
|
||||||
}
|
|
||||||
await fs.appendFileSync(threadMessagePath, JSON.stringify(message) + '\n')
|
|
||||||
Promise.resolve()
|
|
||||||
} catch (err) {
|
|
||||||
Promise.reject(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async storeImage(base64: string, filePath: string): Promise<void> {
|
|
||||||
const base64Data = base64.replace(/^data:image\/\w+;base64,/, '')
|
|
||||||
|
|
||||||
try {
|
|
||||||
await fs.writeBlob(filePath, base64Data)
|
|
||||||
} catch (err) {
|
|
||||||
console.error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async storeFile(base64: string, filePath: string): Promise<void> {
|
|
||||||
const base64Data = base64.replace(/^data:application\/pdf;base64,/, '')
|
|
||||||
try {
|
|
||||||
await fs.writeBlob(filePath, base64Data)
|
|
||||||
} catch (err) {
|
|
||||||
console.error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async writeMessages(
|
|
||||||
threadId: string,
|
|
||||||
messages: ThreadMessage[]
|
|
||||||
): Promise<void> {
|
|
||||||
try {
|
|
||||||
const threadDirPath = await joinPath([
|
|
||||||
JSONConversationalExtension._threadFolder,
|
|
||||||
threadId,
|
|
||||||
])
|
|
||||||
const threadMessagePath = await joinPath([
|
|
||||||
threadDirPath,
|
|
||||||
JSONConversationalExtension._threadMessagesFileName,
|
|
||||||
])
|
|
||||||
if (!(await fs.existsSync(threadDirPath))) await fs.mkdir(threadDirPath)
|
|
||||||
await fs.writeFileSync(
|
|
||||||
threadMessagePath,
|
|
||||||
messages.map((msg) => JSON.stringify(msg)).join('\n') +
|
|
||||||
(messages.length ? '\n' : '')
|
|
||||||
)
|
|
||||||
Promise.resolve()
|
|
||||||
} catch (err) {
|
|
||||||
Promise.reject(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A promise builder for reading a thread from a file.
|
* Adds a new message to a specified thread.
|
||||||
* @param threadDirName the thread dir we are reading from.
|
* @param message The ThreadMessage object to be added.
|
||||||
* @returns data of the thread
|
* @returns A Promise that resolves when the message has been added.
|
||||||
*/
|
*/
|
||||||
async readThread(threadDirName: string): Promise<any> {
|
async createMessage(message: ThreadMessage): Promise<ThreadMessage> {
|
||||||
return fs.readFileSync(
|
return this.queue.add(() =>
|
||||||
await joinPath([
|
ky
|
||||||
JSONConversationalExtension._threadFolder,
|
.post(`${API_URL}/v1/threads/${message.thread_id}/messages`, {
|
||||||
threadDirName,
|
json: message,
|
||||||
JSONConversationalExtension._threadInfoFileName,
|
|
||||||
]),
|
|
||||||
'utf-8'
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a Promise that resolves to an array of thread directories.
|
|
||||||
* @private
|
|
||||||
*/
|
|
||||||
async getValidThreadDirs(): Promise<string[]> {
|
|
||||||
const fileInsideThread: string[] = await fs.readdirSync(
|
|
||||||
JSONConversationalExtension._threadFolder
|
|
||||||
)
|
|
||||||
|
|
||||||
const threadDirs: string[] = []
|
|
||||||
for (let i = 0; i < fileInsideThread.length; i++) {
|
|
||||||
const path = await joinPath([
|
|
||||||
JSONConversationalExtension._threadFolder,
|
|
||||||
fileInsideThread[i],
|
|
||||||
])
|
|
||||||
if (!(await fs.fileStat(path))?.isDirectory) continue
|
|
||||||
|
|
||||||
const isHavingThreadInfo = (await fs.readdirSync(path)).includes(
|
|
||||||
JSONConversationalExtension._threadInfoFileName
|
|
||||||
)
|
|
||||||
if (!isHavingThreadInfo) {
|
|
||||||
console.debug(`Ignore ${path} because it does not have thread info`)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
threadDirs.push(fileInsideThread[i])
|
|
||||||
}
|
|
||||||
return threadDirs
|
|
||||||
}
|
|
||||||
|
|
||||||
async getAllMessages(threadId: string): Promise<ThreadMessage[]> {
|
|
||||||
try {
|
|
||||||
const threadDirPath = await joinPath([
|
|
||||||
JSONConversationalExtension._threadFolder,
|
|
||||||
threadId,
|
|
||||||
])
|
|
||||||
|
|
||||||
const files: string[] = await fs.readdirSync(threadDirPath)
|
|
||||||
if (
|
|
||||||
!files.includes(JSONConversationalExtension._threadMessagesFileName)
|
|
||||||
) {
|
|
||||||
console.debug(`${threadDirPath} not contains message file`)
|
|
||||||
return []
|
|
||||||
}
|
|
||||||
|
|
||||||
const messageFilePath = await joinPath([
|
|
||||||
threadDirPath,
|
|
||||||
JSONConversationalExtension._threadMessagesFileName,
|
|
||||||
])
|
|
||||||
|
|
||||||
let readResult = await fs.readFileSync(messageFilePath, 'utf-8')
|
|
||||||
|
|
||||||
if (typeof readResult === 'object') {
|
|
||||||
readResult = JSON.stringify(readResult)
|
|
||||||
}
|
|
||||||
|
|
||||||
const result = readResult.split('\n').filter((line) => line !== '')
|
|
||||||
|
|
||||||
const messages: ThreadMessage[] = []
|
|
||||||
result.forEach((line: string) => {
|
|
||||||
const message = safelyParseJSON(line)
|
|
||||||
if (message) messages.push(safelyParseJSON(line))
|
|
||||||
})
|
})
|
||||||
return messages
|
.json<ThreadMessage>()
|
||||||
} catch (err) {
|
) as Promise<ThreadMessage>
|
||||||
console.error(err)
|
}
|
||||||
return []
|
|
||||||
}
|
/**
|
||||||
|
* Modifies a message in a thread.
|
||||||
|
* @param message
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
|
async modifyMessage(message: ThreadMessage): Promise<ThreadMessage> {
|
||||||
|
return this.queue.add(() =>
|
||||||
|
ky
|
||||||
|
.post(
|
||||||
|
`${API_URL}/v1/threads/${message.thread_id}/messages/${message.id}`,
|
||||||
|
{
|
||||||
|
json: message,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.json<ThreadMessage>()
|
||||||
|
) as Promise<ThreadMessage>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Deletes a specific message from a thread.
|
||||||
|
* @param threadId The ID of the thread containing the message.
|
||||||
|
* @param messageId The ID of the message to be deleted.
|
||||||
|
* @returns A Promise that resolves when the message has been successfully deleted.
|
||||||
|
*/
|
||||||
|
async deleteMessage(threadId: string, messageId: string): Promise<void> {
|
||||||
|
return this.queue
|
||||||
|
.add(() =>
|
||||||
|
ky.delete(`${API_URL}/v1/threads/${threadId}/messages/${messageId}`)
|
||||||
|
)
|
||||||
|
.then()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Retrieves all messages for a specified thread.
|
||||||
|
* @param threadId The ID of the thread to get messages from.
|
||||||
|
* @returns A Promise that resolves to an array of ThreadMessage objects.
|
||||||
|
*/
|
||||||
|
async listMessages(threadId: string): Promise<ThreadMessage[]> {
|
||||||
|
return this.queue.add(() =>
|
||||||
|
ky
|
||||||
|
.get(`${API_URL}/v1/threads/${threadId}/messages?order=asc`)
|
||||||
|
.json<MessageList>()
|
||||||
|
.then((e) => e.data)
|
||||||
|
) as Promise<ThreadMessage[]>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Retrieves the assistant information for a specified thread.
|
||||||
|
* @param threadId The ID of the thread for which to retrieve assistant information.
|
||||||
|
* @returns A Promise that resolves to a ThreadAssistantInfo object containing
|
||||||
|
* the details of the assistant associated with the specified thread.
|
||||||
|
*/
|
||||||
|
async getThreadAssistant(threadId: string): Promise<ThreadAssistantInfo> {
|
||||||
|
return this.queue.add(() =>
|
||||||
|
ky.get(`${API_URL}/v1/assistants/${threadId}`).json<ThreadAssistantInfo>()
|
||||||
|
) as Promise<ThreadAssistantInfo>
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Creates a new assistant for the specified thread.
|
||||||
|
* @param threadId The ID of the thread for which the assistant is being created.
|
||||||
|
* @param assistant The information about the assistant to be created.
|
||||||
|
* @returns A Promise that resolves to the newly created ThreadAssistantInfo object.
|
||||||
|
*/
|
||||||
|
async createThreadAssistant(
|
||||||
|
threadId: string,
|
||||||
|
assistant: ThreadAssistantInfo
|
||||||
|
): Promise<ThreadAssistantInfo> {
|
||||||
|
return this.queue.add(() =>
|
||||||
|
ky
|
||||||
|
.post(`${API_URL}/v1/assistants/${threadId}`, { json: assistant })
|
||||||
|
.json<ThreadAssistantInfo>()
|
||||||
|
) as Promise<ThreadAssistantInfo>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Modifies an existing assistant for the specified thread.
|
||||||
|
* @param threadId The ID of the thread for which the assistant is being modified.
|
||||||
|
* @param assistant The updated information for the assistant.
|
||||||
|
* @returns A Promise that resolves to the updated ThreadAssistantInfo object.
|
||||||
|
*/
|
||||||
|
async modifyThreadAssistant(
|
||||||
|
threadId: string,
|
||||||
|
assistant: ThreadAssistantInfo
|
||||||
|
): Promise<ThreadAssistantInfo> {
|
||||||
|
return this.queue.add(() =>
|
||||||
|
ky
|
||||||
|
.patch(`${API_URL}/v1/assistants/${threadId}`, { json: assistant })
|
||||||
|
.json<ThreadAssistantInfo>()
|
||||||
|
) as Promise<ThreadAssistantInfo>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Do health check on cortex.cpp
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
|
healthz(): Promise<void> {
|
||||||
|
return ky
|
||||||
|
.get(`${API_URL}/healthz`, {
|
||||||
|
retry: { limit: 20, delay: () => 500, methods: ['get'] },
|
||||||
|
})
|
||||||
|
.then(() => {})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,14 +0,0 @@
|
|||||||
// Note about performance
|
|
||||||
// The v8 JavaScript engine used by Node.js cannot optimise functions which contain a try/catch block.
|
|
||||||
// v8 4.5 and above can optimise try/catch
|
|
||||||
export function safelyParseJSON(json) {
|
|
||||||
// This function cannot be optimised, it's best to
|
|
||||||
// keep it small!
|
|
||||||
var parsed
|
|
||||||
try {
|
|
||||||
parsed = JSON.parse(json)
|
|
||||||
} catch (e) {
|
|
||||||
return undefined
|
|
||||||
}
|
|
||||||
return parsed // Could be undefined!
|
|
||||||
}
|
|
||||||
@ -17,7 +17,12 @@ module.exports = {
|
|||||||
filename: 'index.js', // Adjust the output file name as needed
|
filename: 'index.js', // Adjust the output file name as needed
|
||||||
library: { type: 'module' }, // Specify ESM output format
|
library: { type: 'module' }, // Specify ESM output format
|
||||||
},
|
},
|
||||||
plugins: [new webpack.DefinePlugin({})],
|
plugins: [
|
||||||
|
new webpack.DefinePlugin({
|
||||||
|
API_URL: JSON.stringify('http://127.0.0.1:39291'),
|
||||||
|
SOCKET_URL: JSON.stringify('ws://127.0.0.1:39291'),
|
||||||
|
}),
|
||||||
|
],
|
||||||
resolve: {
|
resolve: {
|
||||||
extensions: ['.ts', '.js'],
|
extensions: ['.ts', '.js'],
|
||||||
},
|
},
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
1.0.4
|
1.0.5-rc1
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
set BIN_PATH=./bin
|
set BIN_PATH=./bin
|
||||||
set SHARED_PATH=./../../electron/shared
|
set SHARED_PATH=./../../electron/shared
|
||||||
set /p CORTEX_VERSION=<./bin/version.txt
|
set /p CORTEX_VERSION=<./bin/version.txt
|
||||||
set ENGINE_VERSION=0.1.40
|
set ENGINE_VERSION=0.1.42
|
||||||
|
|
||||||
@REM Download cortex.llamacpp binaries
|
@REM Download cortex.llamacpp binaries
|
||||||
set DOWNLOAD_URL=https://github.com/janhq/cortex.llamacpp/releases/download/v%ENGINE_VERSION%/cortex.llamacpp-%ENGINE_VERSION%-windows-amd64
|
set DOWNLOAD_URL=https://github.com/janhq/cortex.llamacpp/releases/download/v%ENGINE_VERSION%/cortex.llamacpp-%ENGINE_VERSION%-windows-amd64
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
# Read CORTEX_VERSION
|
# Read CORTEX_VERSION
|
||||||
CORTEX_VERSION=$(cat ./bin/version.txt)
|
CORTEX_VERSION=$(cat ./bin/version.txt)
|
||||||
ENGINE_VERSION=0.1.40
|
ENGINE_VERSION=0.1.42
|
||||||
CORTEX_RELEASE_URL="https://github.com/janhq/cortex.cpp/releases/download"
|
CORTEX_RELEASE_URL="https://github.com/janhq/cortex.cpp/releases/download"
|
||||||
ENGINE_DOWNLOAD_URL="https://github.com/janhq/cortex.llamacpp/releases/download/v${ENGINE_VERSION}/cortex.llamacpp-${ENGINE_VERSION}"
|
ENGINE_DOWNLOAD_URL="https://github.com/janhq/cortex.llamacpp/releases/download/v${ENGINE_VERSION}/cortex.llamacpp-${ENGINE_VERSION}"
|
||||||
CUDA_DOWNLOAD_URL="https://github.com/janhq/cortex.llamacpp/releases/download/v${ENGINE_VERSION}"
|
CUDA_DOWNLOAD_URL="https://github.com/janhq/cortex.llamacpp/releases/download/v${ENGINE_VERSION}"
|
||||||
|
|||||||
@ -120,7 +120,7 @@ export default [
|
|||||||
SETTINGS: JSON.stringify(defaultSettingJson),
|
SETTINGS: JSON.stringify(defaultSettingJson),
|
||||||
CORTEX_API_URL: JSON.stringify('http://127.0.0.1:39291'),
|
CORTEX_API_URL: JSON.stringify('http://127.0.0.1:39291'),
|
||||||
CORTEX_SOCKET_URL: JSON.stringify('ws://127.0.0.1:39291'),
|
CORTEX_SOCKET_URL: JSON.stringify('ws://127.0.0.1:39291'),
|
||||||
CORTEX_ENGINE_VERSION: JSON.stringify('v0.1.40'),
|
CORTEX_ENGINE_VERSION: JSON.stringify('v0.1.42'),
|
||||||
}),
|
}),
|
||||||
// Allow json resolution
|
// Allow json resolution
|
||||||
json(),
|
json(),
|
||||||
|
|||||||
@ -18,14 +18,14 @@ import { isLocalEngine } from '@/utils/modelEngine'
|
|||||||
|
|
||||||
import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
|
import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
|
||||||
|
|
||||||
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom'
|
import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom'
|
||||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
|
||||||
|
|
||||||
const ErrorMessage = ({ message }: { message: ThreadMessage }) => {
|
const ErrorMessage = ({ message }: { message: ThreadMessage }) => {
|
||||||
const setModalTroubleShooting = useSetAtom(modalTroubleShootingAtom)
|
const setModalTroubleShooting = useSetAtom(modalTroubleShootingAtom)
|
||||||
const setMainState = useSetAtom(mainViewStateAtom)
|
const setMainState = useSetAtom(mainViewStateAtom)
|
||||||
const setSelectedSettingScreen = useSetAtom(selectedSettingAtom)
|
const setSelectedSettingScreen = useSetAtom(selectedSettingAtom)
|
||||||
const activeThread = useAtomValue(activeThreadAtom)
|
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||||
|
|
||||||
const defaultDesc = () => {
|
const defaultDesc = () => {
|
||||||
return (
|
return (
|
||||||
@ -46,7 +46,7 @@ const ErrorMessage = ({ message }: { message: ThreadMessage }) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const getEngine = () => {
|
const getEngine = () => {
|
||||||
const engineName = activeThread?.assistants?.[0]?.model?.engine
|
const engineName = activeAssistant?.model?.engine
|
||||||
return engineName ? EngineManager.instance().get(engineName) : null
|
return engineName ? EngineManager.instance().get(engineName) : null
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,7 +89,9 @@ const ErrorMessage = ({ message }: { message: ThreadMessage }) => {
|
|||||||
</span>
|
</span>
|
||||||
) : (
|
) : (
|
||||||
<>
|
<>
|
||||||
<AutoLink text={message.content[0].text.value} />
|
{message?.content[0]?.text?.value && (
|
||||||
|
<AutoLink text={message?.content[0]?.text?.value} />
|
||||||
|
)}
|
||||||
{defaultDesc()}
|
{defaultDesc()}
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
|||||||
@ -46,6 +46,7 @@ import {
|
|||||||
|
|
||||||
import { extensionManager } from '@/extension'
|
import { extensionManager } from '@/extension'
|
||||||
|
|
||||||
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom'
|
import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom'
|
||||||
import {
|
import {
|
||||||
configuredModelsAtom,
|
configuredModelsAtom,
|
||||||
@ -75,6 +76,7 @@ const ModelDropdown = ({
|
|||||||
const [searchText, setSearchText] = useState('')
|
const [searchText, setSearchText] = useState('')
|
||||||
const [open, setOpen] = useState(false)
|
const [open, setOpen] = useState(false)
|
||||||
const activeThread = useAtomValue(activeThreadAtom)
|
const activeThread = useAtomValue(activeThreadAtom)
|
||||||
|
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||||
const downloadingModels = useAtomValue(getDownloadingModelAtom)
|
const downloadingModels = useAtomValue(getDownloadingModelAtom)
|
||||||
const [toggle, setToggle] = useState<HTMLDivElement | null>(null)
|
const [toggle, setToggle] = useState<HTMLDivElement | null>(null)
|
||||||
const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
|
const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
|
||||||
@ -151,17 +153,24 @@ const ModelDropdown = ({
|
|||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!activeThread) return
|
if (!activeThread) return
|
||||||
const modelId = activeThread?.assistants?.[0]?.model?.id
|
const modelId = activeAssistant?.model?.id
|
||||||
|
|
||||||
let model = downloadedModels.find((model) => model.id === modelId)
|
let model = downloadedModels.find((model) => model.id === modelId)
|
||||||
if (!model) {
|
if (!model) {
|
||||||
model = recommendedModel
|
model = recommendedModel
|
||||||
}
|
}
|
||||||
setSelectedModel(model)
|
setSelectedModel(model)
|
||||||
}, [recommendedModel, activeThread, downloadedModels, setSelectedModel])
|
}, [
|
||||||
|
recommendedModel,
|
||||||
|
activeThread,
|
||||||
|
downloadedModels,
|
||||||
|
setSelectedModel,
|
||||||
|
activeAssistant?.model?.id,
|
||||||
|
])
|
||||||
|
|
||||||
const onClickModelItem = useCallback(
|
const onClickModelItem = useCallback(
|
||||||
async (modelId: string) => {
|
async (modelId: string) => {
|
||||||
|
if (!activeAssistant) return
|
||||||
const model = downloadedModels.find((m) => m.id === modelId)
|
const model = downloadedModels.find((m) => m.id === modelId)
|
||||||
setSelectedModel(model)
|
setSelectedModel(model)
|
||||||
setOpen(false)
|
setOpen(false)
|
||||||
@ -172,14 +181,14 @@ const ModelDropdown = ({
|
|||||||
...activeThread,
|
...activeThread,
|
||||||
assistants: [
|
assistants: [
|
||||||
{
|
{
|
||||||
...activeThread.assistants[0],
|
...activeAssistant,
|
||||||
tools: [
|
tools: [
|
||||||
{
|
{
|
||||||
type: 'retrieval',
|
type: 'retrieval',
|
||||||
enabled: isModelSupportRagAndTools(model as Model),
|
enabled: isModelSupportRagAndTools(model as Model),
|
||||||
settings: {
|
settings: {
|
||||||
...(activeThread.assistants[0].tools &&
|
...(activeAssistant.tools &&
|
||||||
activeThread.assistants[0].tools[0]?.settings),
|
activeAssistant.tools[0]?.settings),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@ -219,13 +228,14 @@ const ModelDropdown = ({
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
|
activeAssistant,
|
||||||
downloadedModels,
|
downloadedModels,
|
||||||
activeThread,
|
|
||||||
setSelectedModel,
|
setSelectedModel,
|
||||||
|
activeThread,
|
||||||
|
updateThreadMetadata,
|
||||||
isModelSupportRagAndTools,
|
isModelSupportRagAndTools,
|
||||||
setThreadModelParams,
|
setThreadModelParams,
|
||||||
updateModelParameter,
|
updateModelParameter,
|
||||||
updateThreadMetadata,
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -8,7 +8,7 @@ import { FileInfo } from '@/types/file'
|
|||||||
|
|
||||||
export const editPromptAtom = atom<string>('')
|
export const editPromptAtom = atom<string>('')
|
||||||
export const currentPromptAtom = atom<string>('')
|
export const currentPromptAtom = atom<string>('')
|
||||||
export const fileUploadAtom = atom<FileInfo[]>([])
|
export const fileUploadAtom = atom<FileInfo | undefined>()
|
||||||
|
|
||||||
export const searchAtom = atom<string>('')
|
export const searchAtom = atom<string>('')
|
||||||
|
|
||||||
|
|||||||
@ -31,6 +31,7 @@ import {
|
|||||||
addNewMessageAtom,
|
addNewMessageAtom,
|
||||||
updateMessageAtom,
|
updateMessageAtom,
|
||||||
tokenSpeedAtom,
|
tokenSpeedAtom,
|
||||||
|
deleteMessageAtom,
|
||||||
} from '@/helpers/atoms/ChatMessage.atom'
|
} from '@/helpers/atoms/ChatMessage.atom'
|
||||||
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||||
import {
|
import {
|
||||||
@ -49,6 +50,7 @@ export default function ModelHandler() {
|
|||||||
const addNewMessage = useSetAtom(addNewMessageAtom)
|
const addNewMessage = useSetAtom(addNewMessageAtom)
|
||||||
const updateMessage = useSetAtom(updateMessageAtom)
|
const updateMessage = useSetAtom(updateMessageAtom)
|
||||||
const downloadedModels = useAtomValue(downloadedModelsAtom)
|
const downloadedModels = useAtomValue(downloadedModelsAtom)
|
||||||
|
const deleteMessage = useSetAtom(deleteMessageAtom)
|
||||||
const activeModel = useAtomValue(activeModelAtom)
|
const activeModel = useAtomValue(activeModelAtom)
|
||||||
const setActiveModel = useSetAtom(activeModelAtom)
|
const setActiveModel = useSetAtom(activeModelAtom)
|
||||||
const setStateModel = useSetAtom(stateModelAtom)
|
const setStateModel = useSetAtom(stateModelAtom)
|
||||||
@ -86,7 +88,7 @@ export default function ModelHandler() {
|
|||||||
}, [activeModelParams])
|
}, [activeModelParams])
|
||||||
|
|
||||||
const onNewMessageResponse = useCallback(
|
const onNewMessageResponse = useCallback(
|
||||||
(message: ThreadMessage) => {
|
async (message: ThreadMessage) => {
|
||||||
if (message.type === MessageRequestType.Thread) {
|
if (message.type === MessageRequestType.Thread) {
|
||||||
addNewMessage(message)
|
addNewMessage(message)
|
||||||
}
|
}
|
||||||
@ -154,12 +156,15 @@ export default function ModelHandler() {
|
|||||||
...thread,
|
...thread,
|
||||||
|
|
||||||
title: cleanedMessageContent,
|
title: cleanedMessageContent,
|
||||||
metadata: thread.metadata,
|
metadata: {
|
||||||
|
...thread.metadata,
|
||||||
|
title: cleanedMessageContent,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
extensionManager
|
extensionManager
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
?.saveThread({
|
?.modifyThread({
|
||||||
...updatedThread,
|
...updatedThread,
|
||||||
})
|
})
|
||||||
.then(() => {
|
.then(() => {
|
||||||
@ -233,7 +238,9 @@ export default function ModelHandler() {
|
|||||||
|
|
||||||
const thread = threadsRef.current?.find((e) => e.id == message.thread_id)
|
const thread = threadsRef.current?.find((e) => e.id == message.thread_id)
|
||||||
if (!thread) return
|
if (!thread) return
|
||||||
|
|
||||||
const messageContent = message.content[0]?.text?.value
|
const messageContent = message.content[0]?.text?.value
|
||||||
|
|
||||||
const metadata = {
|
const metadata = {
|
||||||
...thread.metadata,
|
...thread.metadata,
|
||||||
...(messageContent && { lastMessage: messageContent }),
|
...(messageContent && { lastMessage: messageContent }),
|
||||||
@ -246,15 +253,19 @@ export default function ModelHandler() {
|
|||||||
|
|
||||||
extensionManager
|
extensionManager
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
?.saveThread({
|
?.modifyThread({
|
||||||
...thread,
|
...thread,
|
||||||
metadata,
|
metadata,
|
||||||
})
|
})
|
||||||
|
;(async () => {
|
||||||
// If this is not the summary of the Thread, don't need to add it to the Thread
|
const updatedMessage = await extensionManager
|
||||||
extensionManager
|
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
?.addNewMessage(message)
|
?.createMessage(message)
|
||||||
|
if (updatedMessage) {
|
||||||
|
deleteMessage(message.id)
|
||||||
|
addNewMessage(updatedMessage)
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
// Attempt to generate the title of the Thread when needed
|
// Attempt to generate the title of the Thread when needed
|
||||||
generateThreadTitle(message, thread)
|
generateThreadTitle(message, thread)
|
||||||
@ -279,7 +290,9 @@ export default function ModelHandler() {
|
|||||||
|
|
||||||
const generateThreadTitle = (message: ThreadMessage, thread: Thread) => {
|
const generateThreadTitle = (message: ThreadMessage, thread: Thread) => {
|
||||||
// If this is the first ever prompt in the thread
|
// If this is the first ever prompt in the thread
|
||||||
if (thread.title?.trim() !== defaultThreadTitle) {
|
if (
|
||||||
|
(thread.title ?? thread.metadata?.title)?.trim() !== defaultThreadTitle
|
||||||
|
) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -292,11 +305,14 @@ export default function ModelHandler() {
|
|||||||
const updatedThread: Thread = {
|
const updatedThread: Thread = {
|
||||||
...thread,
|
...thread,
|
||||||
title: (thread.metadata?.lastMessage as string) || defaultThreadTitle,
|
title: (thread.metadata?.lastMessage as string) || defaultThreadTitle,
|
||||||
metadata: thread.metadata,
|
metadata: {
|
||||||
|
...thread.metadata,
|
||||||
|
title: (thread.metadata?.lastMessage as string) || defaultThreadTitle,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
return extensionManager
|
return extensionManager
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
?.saveThread({
|
?.modifyThread({
|
||||||
...updatedThread,
|
...updatedThread,
|
||||||
})
|
})
|
||||||
.then(() => {
|
.then(() => {
|
||||||
@ -313,7 +329,7 @@ export default function ModelHandler() {
|
|||||||
|
|
||||||
if (!threadMessages || threadMessages.length === 0) return
|
if (!threadMessages || threadMessages.length === 0) return
|
||||||
|
|
||||||
const summarizeFirstPrompt = `Summarize in a ${maxWordForThreadTitle}-word Title. Give the title only. "${threadMessages[0].content[0].text.value}"`
|
const summarizeFirstPrompt = `Summarize in a ${maxWordForThreadTitle}-word Title. Give the title only. "${threadMessages[0]?.content[0]?.text?.value}"`
|
||||||
|
|
||||||
// Prompt: Given this query from user {query}, return to me the summary in 10 words as the title
|
// Prompt: Given this query from user {query}, return to me the summary in 10 words as the title
|
||||||
const msgId = ulid()
|
const msgId = ulid()
|
||||||
@ -330,6 +346,7 @@ export default function ModelHandler() {
|
|||||||
id: msgId,
|
id: msgId,
|
||||||
threadId: message.thread_id,
|
threadId: message.thread_id,
|
||||||
type: MessageRequestType.Summary,
|
type: MessageRequestType.Summary,
|
||||||
|
attachments: [],
|
||||||
messages,
|
messages,
|
||||||
model: {
|
model: {
|
||||||
...activeModelRef.current,
|
...activeModelRef.current,
|
||||||
|
|||||||
@ -1,4 +1,12 @@
|
|||||||
import { Assistant } from '@janhq/core'
|
import { Assistant, ThreadAssistantInfo } from '@janhq/core'
|
||||||
import { atom } from 'jotai'
|
import { atom } from 'jotai'
|
||||||
|
import { atomWithStorage } from 'jotai/utils'
|
||||||
|
|
||||||
export const assistantsAtom = atom<Assistant[]>([])
|
export const assistantsAtom = atom<Assistant[]>([])
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the current active assistant
|
||||||
|
*/
|
||||||
|
export const activeAssistantAtom = atomWithStorage<
|
||||||
|
ThreadAssistantInfo | undefined
|
||||||
|
>('activeAssistant', undefined, undefined, { getOnInit: true })
|
||||||
|
|||||||
@ -6,6 +6,8 @@ import {
|
|||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
import { atom } from 'jotai'
|
import { atom } from 'jotai'
|
||||||
|
|
||||||
|
import { atomWithStorage } from 'jotai/utils'
|
||||||
|
|
||||||
import {
|
import {
|
||||||
getActiveThreadIdAtom,
|
getActiveThreadIdAtom,
|
||||||
updateThreadStateLastMessageAtom,
|
updateThreadStateLastMessageAtom,
|
||||||
@ -13,15 +15,23 @@ import {
|
|||||||
|
|
||||||
import { TokenSpeed } from '@/types/token'
|
import { TokenSpeed } from '@/types/token'
|
||||||
|
|
||||||
|
const CHAT_MESSAGE_NAME = 'chatMessages'
|
||||||
/**
|
/**
|
||||||
* Stores all chat messages for all threads
|
* Stores all chat messages for all threads
|
||||||
*/
|
*/
|
||||||
export const chatMessages = atom<Record<string, ThreadMessage[]>>({})
|
export const chatMessages = atomWithStorage<Record<string, ThreadMessage[]>>(
|
||||||
|
CHAT_MESSAGE_NAME,
|
||||||
|
{},
|
||||||
|
undefined,
|
||||||
|
{ getOnInit: true }
|
||||||
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Stores the status of the messages load for each thread
|
* Stores the status of the messages load for each thread
|
||||||
*/
|
*/
|
||||||
export const readyThreadsMessagesAtom = atom<Record<string, boolean>>({})
|
export const readyThreadsMessagesAtom = atomWithStorage<
|
||||||
|
Record<string, boolean>
|
||||||
|
>('currentThreadMessages', {}, undefined, { getOnInit: true })
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Store the token speed for current message
|
* Store the token speed for current message
|
||||||
@ -34,6 +44,7 @@ export const getCurrentChatMessagesAtom = atom<ThreadMessage[]>((get) => {
|
|||||||
const activeThreadId = get(getActiveThreadIdAtom)
|
const activeThreadId = get(getActiveThreadIdAtom)
|
||||||
if (!activeThreadId) return []
|
if (!activeThreadId) return []
|
||||||
const messages = get(chatMessages)[activeThreadId]
|
const messages = get(chatMessages)[activeThreadId]
|
||||||
|
if (!Array.isArray(messages)) return []
|
||||||
return messages ?? []
|
return messages ?? []
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@ -58,9 +58,11 @@ describe('Model.atom.ts', () => {
|
|||||||
setAtom.current({ id: '1' } as any)
|
setAtom.current({ id: '1' } as any)
|
||||||
})
|
})
|
||||||
expect(getAtom.current).toEqual([{ id: '1' }])
|
expect(getAtom.current).toEqual([{ id: '1' }])
|
||||||
|
act(() => {
|
||||||
reset.current([])
|
reset.current([])
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
})
|
||||||
|
|
||||||
describe('removeDownloadingModelAtom', () => {
|
describe('removeDownloadingModelAtom', () => {
|
||||||
it('should remove downloading model', async () => {
|
it('should remove downloading model', async () => {
|
||||||
@ -83,9 +85,11 @@ describe('Model.atom.ts', () => {
|
|||||||
removeAtom.current('1')
|
removeAtom.current('1')
|
||||||
})
|
})
|
||||||
expect(getAtom.current).toEqual([])
|
expect(getAtom.current).toEqual([])
|
||||||
|
act(() => {
|
||||||
reset.current([])
|
reset.current([])
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
})
|
||||||
|
|
||||||
describe('removeDownloadedModelAtom', () => {
|
describe('removeDownloadedModelAtom', () => {
|
||||||
it('should remove downloaded model', async () => {
|
it('should remove downloaded model', async () => {
|
||||||
@ -113,9 +117,11 @@ describe('Model.atom.ts', () => {
|
|||||||
removeAtom.current('1')
|
removeAtom.current('1')
|
||||||
})
|
})
|
||||||
expect(getAtom.current).toEqual([])
|
expect(getAtom.current).toEqual([])
|
||||||
|
act(() => {
|
||||||
reset.current([])
|
reset.current([])
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
})
|
||||||
|
|
||||||
describe('importingModelAtom', () => {
|
describe('importingModelAtom', () => {
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import { toaster } from '@/containers/Toast'
|
|||||||
import { LAST_USED_MODEL_ID } from './useRecommendedModel'
|
import { LAST_USED_MODEL_ID } from './useRecommendedModel'
|
||||||
|
|
||||||
import { vulkanEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
import { vulkanEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||||
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||||
|
|
||||||
@ -34,6 +35,7 @@ export function useActiveModel() {
|
|||||||
const setLoadModelError = useSetAtom(loadModelErrorAtom)
|
const setLoadModelError = useSetAtom(loadModelErrorAtom)
|
||||||
const pendingModelLoad = useRef(false)
|
const pendingModelLoad = useRef(false)
|
||||||
const isVulkanEnabled = useAtomValue(vulkanEnabledAtom)
|
const isVulkanEnabled = useAtomValue(vulkanEnabledAtom)
|
||||||
|
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||||
|
|
||||||
const downloadedModelsRef = useRef<Model[]>([])
|
const downloadedModelsRef = useRef<Model[]>([])
|
||||||
|
|
||||||
@ -79,12 +81,12 @@ export function useActiveModel() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Apply thread model settings
|
/// Apply thread model settings
|
||||||
if (activeThread?.assistants[0]?.model.id === modelId) {
|
if (activeAssistant?.model.id === modelId) {
|
||||||
model = {
|
model = {
|
||||||
...model,
|
...model,
|
||||||
settings: {
|
settings: {
|
||||||
...model.settings,
|
...model.settings,
|
||||||
...activeThread.assistants[0].model.settings,
|
...activeAssistant?.model.settings,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -67,7 +67,7 @@ describe('useCreateNewThread', () => {
|
|||||||
} as any)
|
} as any)
|
||||||
})
|
})
|
||||||
|
|
||||||
expect(mockSetAtom).toHaveBeenCalledTimes(6) // Check if all the necessary atoms were set
|
expect(mockSetAtom).toHaveBeenCalledTimes(1)
|
||||||
expect(extensionManager.get).toHaveBeenCalled()
|
expect(extensionManager.get).toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -104,7 +104,7 @@ describe('useCreateNewThread', () => {
|
|||||||
await result.current.requestCreateNewThread({
|
await result.current.requestCreateNewThread({
|
||||||
id: 'assistant1',
|
id: 'assistant1',
|
||||||
name: 'Assistant 1',
|
name: 'Assistant 1',
|
||||||
instructions: "Hello Jan Assistant",
|
instructions: 'Hello Jan Assistant',
|
||||||
model: {
|
model: {
|
||||||
id: 'model1',
|
id: 'model1',
|
||||||
parameters: [],
|
parameters: [],
|
||||||
@ -113,16 +113,8 @@ describe('useCreateNewThread', () => {
|
|||||||
} as any)
|
} as any)
|
||||||
})
|
})
|
||||||
|
|
||||||
expect(mockSetAtom).toHaveBeenCalledTimes(6) // Check if all the necessary atoms were set
|
expect(mockSetAtom).toHaveBeenCalledTimes(1) // Check if all the necessary atoms were set
|
||||||
expect(extensionManager.get).toHaveBeenCalled()
|
expect(extensionManager.get).toHaveBeenCalled()
|
||||||
expect(mockSetAtom).toHaveBeenNthCalledWith(
|
|
||||||
2,
|
|
||||||
expect.objectContaining({
|
|
||||||
assistants: expect.arrayContaining([
|
|
||||||
expect.objectContaining({ instructions: 'Hello Jan Assistant' }),
|
|
||||||
]),
|
|
||||||
})
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should create a new thread with previous instructions', async () => {
|
it('should create a new thread with previous instructions', async () => {
|
||||||
@ -166,16 +158,8 @@ describe('useCreateNewThread', () => {
|
|||||||
} as any)
|
} as any)
|
||||||
})
|
})
|
||||||
|
|
||||||
expect(mockSetAtom).toHaveBeenCalledTimes(6) // Check if all the necessary atoms were set
|
expect(mockSetAtom).toHaveBeenCalledTimes(1) // Check if all the necessary atoms were set
|
||||||
expect(extensionManager.get).toHaveBeenCalled()
|
expect(extensionManager.get).toHaveBeenCalled()
|
||||||
expect(mockSetAtom).toHaveBeenNthCalledWith(
|
|
||||||
2,
|
|
||||||
expect.objectContaining({
|
|
||||||
assistants: expect.arrayContaining([
|
|
||||||
expect.objectContaining({ instructions: 'Hello Jan' }),
|
|
||||||
]),
|
|
||||||
})
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should show a warning toast if trying to create an empty thread', async () => {
|
it('should show a warning toast if trying to create an empty thread', async () => {
|
||||||
@ -212,13 +196,12 @@ describe('useCreateNewThread', () => {
|
|||||||
|
|
||||||
const { result } = renderHook(() => useCreateNewThread())
|
const { result } = renderHook(() => useCreateNewThread())
|
||||||
|
|
||||||
const mockThread = { id: 'thread1', title: 'Test Thread' }
|
const mockThread = { id: 'thread1', title: 'Test Thread', assistants: [{}] }
|
||||||
|
|
||||||
await act(async () => {
|
await act(async () => {
|
||||||
await result.current.updateThreadMetadata(mockThread as any)
|
await result.current.updateThreadMetadata(mockThread as any)
|
||||||
})
|
})
|
||||||
|
|
||||||
expect(mockUpdateThread).toHaveBeenCalledWith(mockThread)
|
expect(mockUpdateThread).toHaveBeenCalledWith(mockThread)
|
||||||
expect(extensionManager.get).toHaveBeenCalled()
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
import { useCallback } from 'react'
|
import { useCallback } from 'react'
|
||||||
|
|
||||||
import {
|
import {
|
||||||
Assistant,
|
|
||||||
ConversationalExtension,
|
ConversationalExtension,
|
||||||
ExtensionTypeEnum,
|
ExtensionTypeEnum,
|
||||||
Thread,
|
Thread,
|
||||||
@ -9,8 +8,11 @@ import {
|
|||||||
ThreadState,
|
ThreadState,
|
||||||
AssistantTool,
|
AssistantTool,
|
||||||
Model,
|
Model,
|
||||||
|
Assistant,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
import { atom, useAtomValue, useSetAtom } from 'jotai'
|
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||||
|
|
||||||
|
import { useDebouncedCallback } from 'use-debounce'
|
||||||
|
|
||||||
import { copyOverInstructionEnabledAtom } from '@/containers/CopyInstruction'
|
import { copyOverInstructionEnabledAtom } from '@/containers/CopyInstruction'
|
||||||
import { fileUploadAtom } from '@/containers/Providers/Jotai'
|
import { fileUploadAtom } from '@/containers/Providers/Jotai'
|
||||||
@ -18,7 +20,6 @@ import { fileUploadAtom } from '@/containers/Providers/Jotai'
|
|||||||
import { toaster } from '@/containers/Toast'
|
import { toaster } from '@/containers/Toast'
|
||||||
|
|
||||||
import { isLocalEngine } from '@/utils/modelEngine'
|
import { isLocalEngine } from '@/utils/modelEngine'
|
||||||
import { generateThreadId } from '@/utils/thread'
|
|
||||||
|
|
||||||
import { useActiveModel } from './useActiveModel'
|
import { useActiveModel } from './useActiveModel'
|
||||||
import useRecommendedModel from './useRecommendedModel'
|
import useRecommendedModel from './useRecommendedModel'
|
||||||
@ -28,6 +29,7 @@ import useSetActiveThread from './useSetActiveThread'
|
|||||||
import { extensionManager } from '@/extension'
|
import { extensionManager } from '@/extension'
|
||||||
|
|
||||||
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||||
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||||
import {
|
import {
|
||||||
threadsAtom,
|
threadsAtom,
|
||||||
@ -35,7 +37,6 @@ import {
|
|||||||
updateThreadAtom,
|
updateThreadAtom,
|
||||||
setThreadModelParamsAtom,
|
setThreadModelParamsAtom,
|
||||||
isGeneratingResponseAtom,
|
isGeneratingResponseAtom,
|
||||||
activeThreadAtom,
|
|
||||||
} from '@/helpers/atoms/Thread.atom'
|
} from '@/helpers/atoms/Thread.atom'
|
||||||
|
|
||||||
const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
|
const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
|
||||||
@ -65,7 +66,7 @@ export const useCreateNewThread = () => {
|
|||||||
const copyOverInstructionEnabled = useAtomValue(
|
const copyOverInstructionEnabled = useAtomValue(
|
||||||
copyOverInstructionEnabledAtom
|
copyOverInstructionEnabledAtom
|
||||||
)
|
)
|
||||||
const activeThread = useAtomValue(activeThreadAtom)
|
const [activeAssistant, setActiveAssistant] = useAtom(activeAssistantAtom)
|
||||||
|
|
||||||
const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom)
|
const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom)
|
||||||
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
|
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
|
||||||
@ -76,7 +77,7 @@ export const useCreateNewThread = () => {
|
|||||||
const { stopInference } = useActiveModel()
|
const { stopInference } = useActiveModel()
|
||||||
|
|
||||||
const requestCreateNewThread = async (
|
const requestCreateNewThread = async (
|
||||||
assistant: Assistant,
|
assistant: (ThreadAssistantInfo & { id: string; name: string }) | Assistant,
|
||||||
model?: Model | undefined
|
model?: Model | undefined
|
||||||
) => {
|
) => {
|
||||||
// Stop generating if any
|
// Stop generating if any
|
||||||
@ -127,7 +128,7 @@ export const useCreateNewThread = () => {
|
|||||||
const createdAt = Date.now()
|
const createdAt = Date.now()
|
||||||
let instructions: string | undefined = assistant.instructions
|
let instructions: string | undefined = assistant.instructions
|
||||||
if (copyOverInstructionEnabled) {
|
if (copyOverInstructionEnabled) {
|
||||||
instructions = activeThread?.assistants[0]?.instructions ?? undefined
|
instructions = activeAssistant?.instructions ?? undefined
|
||||||
}
|
}
|
||||||
const assistantInfo: ThreadAssistantInfo = {
|
const assistantInfo: ThreadAssistantInfo = {
|
||||||
assistant_id: assistant.id,
|
assistant_id: assistant.id,
|
||||||
@ -142,45 +143,94 @@ export const useCreateNewThread = () => {
|
|||||||
instructions,
|
instructions,
|
||||||
}
|
}
|
||||||
|
|
||||||
const threadId = generateThreadId(assistant.id)
|
const thread: Partial<Thread> = {
|
||||||
const thread: Thread = {
|
|
||||||
id: threadId,
|
|
||||||
object: 'thread',
|
object: 'thread',
|
||||||
title: 'New Thread',
|
title: 'New Thread',
|
||||||
assistants: [assistantInfo],
|
assistants: [assistantInfo],
|
||||||
created: createdAt,
|
created: createdAt,
|
||||||
updated: createdAt,
|
updated: createdAt,
|
||||||
|
metadata: {
|
||||||
|
title: 'New Thread',
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// add the new thread on top of the thread list to the state
|
// add the new thread on top of the thread list to the state
|
||||||
//TODO: Why do we have thread list then thread states? Should combine them
|
//TODO: Why do we have thread list then thread states? Should combine them
|
||||||
createNewThread(thread)
|
try {
|
||||||
|
const createdThread = await persistNewThread(thread, assistantInfo)
|
||||||
|
if (!createdThread) throw 'Thread created failed.'
|
||||||
|
createNewThread(createdThread)
|
||||||
|
|
||||||
setSelectedModel(defaultModel)
|
setSelectedModel(defaultModel)
|
||||||
setThreadModelParams(thread.id, {
|
setThreadModelParams(createdThread.id, {
|
||||||
...defaultModel?.settings,
|
...defaultModel?.settings,
|
||||||
...defaultModel?.parameters,
|
...defaultModel?.parameters,
|
||||||
...overriddenSettings,
|
...overriddenSettings,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Delete the file upload state
|
// Delete the file upload state
|
||||||
setFileUpload([])
|
setFileUpload(undefined)
|
||||||
// Update thread metadata
|
setActiveThread(createdThread)
|
||||||
await updateThreadMetadata(thread)
|
} catch (ex) {
|
||||||
|
return toaster({
|
||||||
setActiveThread(thread)
|
title: 'Thread created failed.',
|
||||||
|
description: `To avoid piling up empty threads, please reuse previous one before creating new.`,
|
||||||
|
type: 'error',
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const updateThreadExtension = (thread: Thread) => {
|
||||||
|
return extensionManager
|
||||||
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
|
?.modifyThread(thread)
|
||||||
|
}
|
||||||
|
|
||||||
|
const updateAssistantExtension = (
|
||||||
|
threadId: string,
|
||||||
|
assistant: ThreadAssistantInfo
|
||||||
|
) => {
|
||||||
|
return extensionManager
|
||||||
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
|
?.modifyThreadAssistant(threadId, assistant)
|
||||||
|
}
|
||||||
|
|
||||||
|
const updateThreadCallback = useDebouncedCallback(updateThreadExtension, 300)
|
||||||
|
const updateAssistantCallback = useDebouncedCallback(
|
||||||
|
updateAssistantExtension,
|
||||||
|
300
|
||||||
|
)
|
||||||
|
|
||||||
const updateThreadMetadata = useCallback(
|
const updateThreadMetadata = useCallback(
|
||||||
async (thread: Thread) => {
|
async (thread: Thread) => {
|
||||||
updateThread(thread)
|
updateThread(thread)
|
||||||
|
|
||||||
|
setActiveAssistant(thread.assistants[0])
|
||||||
|
updateThreadCallback(thread)
|
||||||
|
updateAssistantCallback(thread.id, thread.assistants[0])
|
||||||
|
},
|
||||||
|
[
|
||||||
|
updateThread,
|
||||||
|
setActiveAssistant,
|
||||||
|
updateThreadCallback,
|
||||||
|
updateAssistantCallback,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
const persistNewThread = async (
|
||||||
|
thread: Partial<Thread>,
|
||||||
|
assistantInfo: ThreadAssistantInfo
|
||||||
|
): Promise<Thread | undefined> => {
|
||||||
|
return await extensionManager
|
||||||
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
|
?.createThread(thread)
|
||||||
|
.then(async (thread) => {
|
||||||
await extensionManager
|
await extensionManager
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
?.saveThread(thread)
|
?.createThreadAssistant(thread.id, assistantInfo)
|
||||||
},
|
return thread
|
||||||
[updateThread]
|
})
|
||||||
)
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
requestCreateNewThread,
|
requestCreateNewThread,
|
||||||
|
|||||||
@ -2,8 +2,7 @@ import { renderHook, act } from '@testing-library/react'
|
|||||||
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
|
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||||
import useDeleteThread from './useDeleteThread'
|
import useDeleteThread from './useDeleteThread'
|
||||||
import { extensionManager } from '@/extension/ExtensionManager'
|
import { extensionManager } from '@/extension/ExtensionManager'
|
||||||
import { toaster } from '@/containers/Toast'
|
import { useCreateNewThread } from './useCreateNewThread'
|
||||||
|
|
||||||
// Mock the necessary dependencies
|
// Mock the necessary dependencies
|
||||||
// Mock dependencies
|
// Mock dependencies
|
||||||
jest.mock('jotai', () => ({
|
jest.mock('jotai', () => ({
|
||||||
@ -12,6 +11,7 @@ jest.mock('jotai', () => ({
|
|||||||
useAtom: jest.fn(),
|
useAtom: jest.fn(),
|
||||||
atom: jest.fn(),
|
atom: jest.fn(),
|
||||||
}))
|
}))
|
||||||
|
jest.mock('./useCreateNewThread')
|
||||||
jest.mock('@/extension/ExtensionManager')
|
jest.mock('@/extension/ExtensionManager')
|
||||||
jest.mock('@/containers/Toast')
|
jest.mock('@/containers/Toast')
|
||||||
|
|
||||||
@ -27,8 +27,13 @@ describe('useDeleteThread', () => {
|
|||||||
]
|
]
|
||||||
const mockSetThreads = jest.fn()
|
const mockSetThreads = jest.fn()
|
||||||
;(useAtom as jest.Mock).mockReturnValue([mockThreads, mockSetThreads])
|
;(useAtom as jest.Mock).mockReturnValue([mockThreads, mockSetThreads])
|
||||||
|
;(useSetAtom as jest.Mock).mockReturnValue(() => {})
|
||||||
|
;(useCreateNewThread as jest.Mock).mockReturnValue({})
|
||||||
|
|
||||||
|
const mockDeleteThread = jest.fn().mockImplementation(() => ({
|
||||||
|
catch: () => jest.fn,
|
||||||
|
}))
|
||||||
|
|
||||||
const mockDeleteThread = jest.fn()
|
|
||||||
extensionManager.get = jest.fn().mockReturnValue({
|
extensionManager.get = jest.fn().mockReturnValue({
|
||||||
deleteThread: mockDeleteThread,
|
deleteThread: mockDeleteThread,
|
||||||
})
|
})
|
||||||
@ -50,12 +55,17 @@ describe('useDeleteThread', () => {
|
|||||||
const mockCleanMessages = jest.fn()
|
const mockCleanMessages = jest.fn()
|
||||||
;(useSetAtom as jest.Mock).mockReturnValue(() => mockCleanMessages)
|
;(useSetAtom as jest.Mock).mockReturnValue(() => mockCleanMessages)
|
||||||
;(useAtomValue as jest.Mock).mockReturnValue(['thread 1'])
|
;(useAtomValue as jest.Mock).mockReturnValue(['thread 1'])
|
||||||
|
const mockCreateNewThread = jest.fn()
|
||||||
|
;(useCreateNewThread as jest.Mock).mockReturnValue({
|
||||||
|
requestCreateNewThread: mockCreateNewThread,
|
||||||
|
})
|
||||||
|
|
||||||
const mockWriteMessages = jest.fn()
|
|
||||||
const mockSaveThread = jest.fn()
|
const mockSaveThread = jest.fn()
|
||||||
|
const mockDeleteThread = jest.fn().mockResolvedValue({})
|
||||||
extensionManager.get = jest.fn().mockReturnValue({
|
extensionManager.get = jest.fn().mockReturnValue({
|
||||||
writeMessages: mockWriteMessages,
|
|
||||||
saveThread: mockSaveThread,
|
saveThread: mockSaveThread,
|
||||||
|
getThreadAssistant: jest.fn().mockResolvedValue({}),
|
||||||
|
deleteThread: mockDeleteThread,
|
||||||
})
|
})
|
||||||
|
|
||||||
const { result } = renderHook(() => useDeleteThread())
|
const { result } = renderHook(() => useDeleteThread())
|
||||||
@ -64,20 +74,18 @@ describe('useDeleteThread', () => {
|
|||||||
await result.current.cleanThread('thread1')
|
await result.current.cleanThread('thread1')
|
||||||
})
|
})
|
||||||
|
|
||||||
expect(mockWriteMessages).toHaveBeenCalled()
|
expect(mockDeleteThread).toHaveBeenCalled()
|
||||||
expect(mockSaveThread).toHaveBeenCalledWith(
|
expect(mockCreateNewThread).toHaveBeenCalled()
|
||||||
expect.objectContaining({
|
|
||||||
id: 'thread1',
|
|
||||||
title: 'New Thread',
|
|
||||||
metadata: expect.objectContaining({ lastMessage: undefined }),
|
|
||||||
})
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should handle errors when deleting a thread', async () => {
|
it('should handle errors when deleting a thread', async () => {
|
||||||
const mockThreads = [{ id: 'thread1', title: 'Thread 1' }]
|
const mockThreads = [{ id: 'thread1', title: 'Thread 1' }]
|
||||||
const mockSetThreads = jest.fn()
|
const mockSetThreads = jest.fn()
|
||||||
;(useAtom as jest.Mock).mockReturnValue([mockThreads, mockSetThreads])
|
;(useAtom as jest.Mock).mockReturnValue([mockThreads, mockSetThreads])
|
||||||
|
const mockCreateNewThread = jest.fn()
|
||||||
|
;(useCreateNewThread as jest.Mock).mockReturnValue({
|
||||||
|
requestCreateNewThread: mockCreateNewThread,
|
||||||
|
})
|
||||||
|
|
||||||
const mockDeleteThread = jest
|
const mockDeleteThread = jest
|
||||||
.fn()
|
.fn()
|
||||||
@ -98,8 +106,6 @@ describe('useDeleteThread', () => {
|
|||||||
|
|
||||||
expect(mockDeleteThread).toHaveBeenCalledWith('thread1')
|
expect(mockDeleteThread).toHaveBeenCalledWith('thread1')
|
||||||
expect(consoleErrorSpy).toHaveBeenCalledWith(expect.any(Error))
|
expect(consoleErrorSpy).toHaveBeenCalledWith(expect.any(Error))
|
||||||
expect(mockSetThreads).not.toHaveBeenCalled()
|
|
||||||
expect(toaster).not.toHaveBeenCalled()
|
|
||||||
|
|
||||||
consoleErrorSpy.mockRestore()
|
consoleErrorSpy.mockRestore()
|
||||||
})
|
})
|
||||||
|
|||||||
@ -1,13 +1,6 @@
|
|||||||
import { useCallback } from 'react'
|
import { useCallback } from 'react'
|
||||||
|
|
||||||
import {
|
import { ExtensionTypeEnum, ConversationalExtension } from '@janhq/core'
|
||||||
ChatCompletionRole,
|
|
||||||
ExtensionTypeEnum,
|
|
||||||
ConversationalExtension,
|
|
||||||
fs,
|
|
||||||
joinPath,
|
|
||||||
Thread,
|
|
||||||
} from '@janhq/core'
|
|
||||||
|
|
||||||
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
|
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||||
|
|
||||||
@ -15,89 +8,63 @@ import { currentPromptAtom } from '@/containers/Providers/Jotai'
|
|||||||
|
|
||||||
import { toaster } from '@/containers/Toast'
|
import { toaster } from '@/containers/Toast'
|
||||||
|
|
||||||
|
import { useCreateNewThread } from './useCreateNewThread'
|
||||||
|
|
||||||
import { extensionManager } from '@/extension/ExtensionManager'
|
import { extensionManager } from '@/extension/ExtensionManager'
|
||||||
|
|
||||||
import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom'
|
import { assistantsAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
import {
|
import { deleteChatMessageAtom as deleteChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
|
||||||
chatMessages,
|
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||||
cleanChatMessageAtom as cleanChatMessagesAtom,
|
|
||||||
deleteChatMessageAtom as deleteChatMessagesAtom,
|
|
||||||
} from '@/helpers/atoms/ChatMessage.atom'
|
|
||||||
import {
|
import {
|
||||||
threadsAtom,
|
threadsAtom,
|
||||||
setActiveThreadIdAtom,
|
setActiveThreadIdAtom,
|
||||||
deleteThreadStateAtom,
|
deleteThreadStateAtom,
|
||||||
updateThreadStateLastMessageAtom,
|
|
||||||
updateThreadAtom,
|
|
||||||
} from '@/helpers/atoms/Thread.atom'
|
} from '@/helpers/atoms/Thread.atom'
|
||||||
|
|
||||||
export default function useDeleteThread() {
|
export default function useDeleteThread() {
|
||||||
const [threads, setThreads] = useAtom(threadsAtom)
|
const [threads, setThreads] = useAtom(threadsAtom)
|
||||||
const messages = useAtomValue(chatMessages)
|
const { requestCreateNewThread } = useCreateNewThread()
|
||||||
const janDataFolderPath = useAtomValue(janDataFolderPathAtom)
|
const assistants = useAtomValue(assistantsAtom)
|
||||||
|
const models = useAtomValue(downloadedModelsAtom)
|
||||||
|
|
||||||
const setCurrentPrompt = useSetAtom(currentPromptAtom)
|
const setCurrentPrompt = useSetAtom(currentPromptAtom)
|
||||||
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
|
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
|
||||||
const deleteMessages = useSetAtom(deleteChatMessagesAtom)
|
const deleteMessages = useSetAtom(deleteChatMessagesAtom)
|
||||||
const cleanMessages = useSetAtom(cleanChatMessagesAtom)
|
|
||||||
|
|
||||||
const deleteThreadState = useSetAtom(deleteThreadStateAtom)
|
const deleteThreadState = useSetAtom(deleteThreadStateAtom)
|
||||||
const updateThreadLastMessage = useSetAtom(updateThreadStateLastMessageAtom)
|
|
||||||
const updateThread = useSetAtom(updateThreadAtom)
|
|
||||||
|
|
||||||
const cleanThread = useCallback(
|
const cleanThread = useCallback(
|
||||||
async (threadId: string) => {
|
async (threadId: string) => {
|
||||||
cleanMessages(threadId)
|
|
||||||
const thread = threads.find((c) => c.id === threadId)
|
const thread = threads.find((c) => c.id === threadId)
|
||||||
if (!thread) return
|
if (!thread) return
|
||||||
|
const assistantInfo = await extensionManager
|
||||||
const updatedMessages = (messages[threadId] ?? []).filter(
|
|
||||||
(msg) => msg.role === ChatCompletionRole.System
|
|
||||||
)
|
|
||||||
|
|
||||||
// remove files
|
|
||||||
try {
|
|
||||||
const threadFolderPath = await joinPath([
|
|
||||||
janDataFolderPath,
|
|
||||||
'threads',
|
|
||||||
threadId,
|
|
||||||
])
|
|
||||||
const threadFilesPath = await joinPath([threadFolderPath, 'files'])
|
|
||||||
const threadMemoryPath = await joinPath([threadFolderPath, 'memory'])
|
|
||||||
await fs.rm(threadFilesPath)
|
|
||||||
await fs.rm(threadMemoryPath)
|
|
||||||
} catch (err) {
|
|
||||||
console.warn('Error deleting thread files', err)
|
|
||||||
}
|
|
||||||
|
|
||||||
await extensionManager
|
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
?.writeMessages(threadId, updatedMessages)
|
?.getThreadAssistant(thread.id)
|
||||||
|
|
||||||
thread.metadata = {
|
if (!assistantInfo) return
|
||||||
...thread.metadata,
|
const model = models.find((c) => c.id === assistantInfo?.model?.id)
|
||||||
}
|
|
||||||
|
|
||||||
const updatedThread: Thread = {
|
requestCreateNewThread(
|
||||||
...thread,
|
{
|
||||||
title: 'New Thread',
|
...assistantInfo,
|
||||||
metadata: { ...thread.metadata, lastMessage: undefined },
|
id: assistants[0].id,
|
||||||
}
|
name: assistants[0].name,
|
||||||
|
|
||||||
await extensionManager
|
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
|
||||||
?.saveThread(updatedThread)
|
|
||||||
updateThreadLastMessage(threadId, undefined)
|
|
||||||
updateThread(updatedThread)
|
|
||||||
},
|
},
|
||||||
[
|
model
|
||||||
cleanMessages,
|
? {
|
||||||
threads,
|
...model,
|
||||||
messages,
|
parameters: assistantInfo?.model?.parameters ?? {},
|
||||||
updateThreadLastMessage,
|
settings: assistantInfo?.model?.settings ?? {},
|
||||||
updateThread,
|
}
|
||||||
janDataFolderPath,
|
: undefined
|
||||||
]
|
)
|
||||||
|
// Delete this thread
|
||||||
|
await extensionManager
|
||||||
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
|
?.deleteThread(threadId)
|
||||||
|
.catch(console.error)
|
||||||
|
},
|
||||||
|
[assistants, models, requestCreateNewThread, threads]
|
||||||
)
|
)
|
||||||
|
|
||||||
const deleteThread = async (threadId: string) => {
|
const deleteThread = async (threadId: string) => {
|
||||||
@ -105,10 +72,10 @@ export default function useDeleteThread() {
|
|||||||
alert('No active thread')
|
alert('No active thread')
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
try {
|
|
||||||
await extensionManager
|
await extensionManager
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
?.deleteThread(threadId)
|
?.deleteThread(threadId)
|
||||||
|
.catch(console.error)
|
||||||
const availableThreads = threads.filter((c) => c.id !== threadId)
|
const availableThreads = threads.filter((c) => c.id !== threadId)
|
||||||
setThreads(availableThreads)
|
setThreads(availableThreads)
|
||||||
|
|
||||||
@ -127,9 +94,6 @@ export default function useDeleteThread() {
|
|||||||
} else {
|
} else {
|
||||||
setActiveThreadId(undefined)
|
setActiveThreadId(undefined)
|
||||||
}
|
}
|
||||||
} catch (err) {
|
|
||||||
console.error(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -1,3 +1,6 @@
|
|||||||
|
/**
|
||||||
|
* @jest-environment jsdom
|
||||||
|
*/
|
||||||
// useDropModelBinaries.test.ts
|
// useDropModelBinaries.test.ts
|
||||||
|
|
||||||
import { renderHook, act } from '@testing-library/react'
|
import { renderHook, act } from '@testing-library/react'
|
||||||
@ -18,6 +21,7 @@ jest.mock('jotai', () => ({
|
|||||||
jest.mock('uuid')
|
jest.mock('uuid')
|
||||||
jest.mock('@/utils/file')
|
jest.mock('@/utils/file')
|
||||||
jest.mock('@/containers/Toast')
|
jest.mock('@/containers/Toast')
|
||||||
|
jest.mock("@uppy/core")
|
||||||
|
|
||||||
describe('useDropModelBinaries', () => {
|
describe('useDropModelBinaries', () => {
|
||||||
const mockSetImportingModels = jest.fn()
|
const mockSetImportingModels = jest.fn()
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import { openFileExplorer, joinPath, baseName } from '@janhq/core'
|
|||||||
import { useAtomValue } from 'jotai'
|
import { useAtomValue } from 'jotai'
|
||||||
|
|
||||||
import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom'
|
import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||||
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||||
|
|
||||||
@ -9,13 +10,14 @@ export const usePath = () => {
|
|||||||
const janDataFolderPath = useAtomValue(janDataFolderPathAtom)
|
const janDataFolderPath = useAtomValue(janDataFolderPathAtom)
|
||||||
const activeThread = useAtomValue(activeThreadAtom)
|
const activeThread = useAtomValue(activeThreadAtom)
|
||||||
const selectedModel = useAtomValue(selectedModelAtom)
|
const selectedModel = useAtomValue(selectedModelAtom)
|
||||||
|
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||||
|
|
||||||
const onRevealInFinder = async (type: string) => {
|
const onRevealInFinder = async (type: string) => {
|
||||||
// TODO: this logic should be refactored.
|
// TODO: this logic should be refactored.
|
||||||
if (type !== 'Model' && !activeThread) return
|
if (type !== 'Model' && !activeThread) return
|
||||||
|
|
||||||
let filePath = undefined
|
let filePath = undefined
|
||||||
const assistantId = activeThread?.assistants[0]?.assistant_id
|
const assistantId = activeAssistant?.assistant_id
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case 'Engine':
|
case 'Engine':
|
||||||
case 'Thread':
|
case 'Thread':
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import { atom, useAtomValue } from 'jotai'
|
|||||||
|
|
||||||
import { activeModelAtom } from './useActiveModel'
|
import { activeModelAtom } from './useActiveModel'
|
||||||
|
|
||||||
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||||
|
|
||||||
@ -28,6 +29,7 @@ export default function useRecommendedModel() {
|
|||||||
const [recommendedModel, setRecommendedModel] = useState<Model | undefined>()
|
const [recommendedModel, setRecommendedModel] = useState<Model | undefined>()
|
||||||
const activeThread = useAtomValue(activeThreadAtom)
|
const activeThread = useAtomValue(activeThreadAtom)
|
||||||
const downloadedModels = useAtomValue(downloadedModelsAtom)
|
const downloadedModels = useAtomValue(downloadedModelsAtom)
|
||||||
|
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||||
|
|
||||||
const getAndSortDownloadedModels = useCallback(async (): Promise<Model[]> => {
|
const getAndSortDownloadedModels = useCallback(async (): Promise<Model[]> => {
|
||||||
const models = downloadedModels.sort((a, b) =>
|
const models = downloadedModels.sort((a, b) =>
|
||||||
@ -45,8 +47,8 @@ export default function useRecommendedModel() {
|
|||||||
> => {
|
> => {
|
||||||
const models = await getAndSortDownloadedModels()
|
const models = await getAndSortDownloadedModels()
|
||||||
|
|
||||||
if (!activeThread) return
|
if (!activeThread || !activeAssistant) return
|
||||||
const modelId = activeThread.assistants[0]?.model.id
|
const modelId = activeAssistant.model.id
|
||||||
const model = models.find((model) => model.id === modelId)
|
const model = models.find((model) => model.id === modelId)
|
||||||
|
|
||||||
if (model) {
|
if (model) {
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import {
|
|||||||
ConversationalExtension,
|
ConversationalExtension,
|
||||||
EngineManager,
|
EngineManager,
|
||||||
ToolManager,
|
ToolManager,
|
||||||
|
ThreadAssistantInfo,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
|
import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
|
||||||
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||||
@ -28,6 +29,7 @@ import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
|
|||||||
import { useActiveModel } from './useActiveModel'
|
import { useActiveModel } from './useActiveModel'
|
||||||
|
|
||||||
import { extensionManager } from '@/extension/ExtensionManager'
|
import { extensionManager } from '@/extension/ExtensionManager'
|
||||||
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
import {
|
import {
|
||||||
addNewMessageAtom,
|
addNewMessageAtom,
|
||||||
deleteMessageAtom,
|
deleteMessageAtom,
|
||||||
@ -48,6 +50,7 @@ export const reloadModelAtom = atom(false)
|
|||||||
|
|
||||||
export default function useSendChatMessage() {
|
export default function useSendChatMessage() {
|
||||||
const activeThread = useAtomValue(activeThreadAtom)
|
const activeThread = useAtomValue(activeThreadAtom)
|
||||||
|
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||||
const addNewMessage = useSetAtom(addNewMessageAtom)
|
const addNewMessage = useSetAtom(addNewMessageAtom)
|
||||||
const updateThread = useSetAtom(updateThreadAtom)
|
const updateThread = useSetAtom(updateThreadAtom)
|
||||||
const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
|
const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
|
||||||
@ -68,6 +71,7 @@ export default function useSendChatMessage() {
|
|||||||
const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
|
const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
|
||||||
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
|
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
|
||||||
const activeThreadRef = useRef<Thread | undefined>()
|
const activeThreadRef = useRef<Thread | undefined>()
|
||||||
|
const activeAssistantRef = useRef<ThreadAssistantInfo | undefined>()
|
||||||
const setTokenSpeed = useSetAtom(tokenSpeedAtom)
|
const setTokenSpeed = useSetAtom(tokenSpeedAtom)
|
||||||
|
|
||||||
const selectedModelRef = useRef<Model | undefined>()
|
const selectedModelRef = useRef<Model | undefined>()
|
||||||
@ -84,36 +88,37 @@ export default function useSendChatMessage() {
|
|||||||
selectedModelRef.current = selectedModel
|
selectedModelRef.current = selectedModel
|
||||||
}, [selectedModel])
|
}, [selectedModel])
|
||||||
|
|
||||||
const resendChatMessage = async (currentMessage: ThreadMessage) => {
|
useEffect(() => {
|
||||||
|
activeAssistantRef.current = activeAssistant
|
||||||
|
}, [activeAssistant])
|
||||||
|
|
||||||
|
const resendChatMessage = async () => {
|
||||||
// Delete last response before regenerating
|
// Delete last response before regenerating
|
||||||
const newConvoData = currentMessages
|
const newConvoData = Array.from(currentMessages)
|
||||||
let toSendMessage = currentMessage
|
let toSendMessage = newConvoData.pop()
|
||||||
|
|
||||||
do {
|
while (toSendMessage && toSendMessage?.role !== ChatCompletionRole.User) {
|
||||||
deleteMessage(currentMessage.id)
|
|
||||||
const msg = newConvoData.pop()
|
|
||||||
if (!msg) break
|
|
||||||
toSendMessage = msg
|
|
||||||
deleteMessage(toSendMessage.id ?? '')
|
|
||||||
} while (toSendMessage.role !== ChatCompletionRole.User)
|
|
||||||
|
|
||||||
if (activeThreadRef.current) {
|
|
||||||
await extensionManager
|
await extensionManager
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
?.writeMessages(activeThreadRef.current.id, newConvoData)
|
?.deleteMessage(toSendMessage.thread_id, toSendMessage.id)
|
||||||
|
.catch(console.error)
|
||||||
|
deleteMessage(toSendMessage.id ?? '')
|
||||||
|
toSendMessage = newConvoData.pop()
|
||||||
}
|
}
|
||||||
|
|
||||||
sendChatMessage(toSendMessage.content[0]?.text.value)
|
if (toSendMessage?.content[0]?.text?.value)
|
||||||
|
sendChatMessage(toSendMessage.content[0].text.value, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
const sendChatMessage = async (
|
const sendChatMessage = async (
|
||||||
message: string,
|
message: string,
|
||||||
|
isResend: boolean = false,
|
||||||
messages?: ThreadMessage[]
|
messages?: ThreadMessage[]
|
||||||
) => {
|
) => {
|
||||||
if (!message || message.trim().length === 0) return
|
if (!message || message.trim().length === 0) return
|
||||||
|
|
||||||
if (!activeThreadRef.current) {
|
if (!activeThreadRef.current || !activeAssistantRef.current) {
|
||||||
console.error('No active thread')
|
console.error('No active thread or assistant')
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -129,21 +134,19 @@ export default function useSendChatMessage() {
|
|||||||
setCurrentPrompt('')
|
setCurrentPrompt('')
|
||||||
setEditPrompt('')
|
setEditPrompt('')
|
||||||
|
|
||||||
let base64Blob = fileUpload[0]
|
let base64Blob = fileUpload ? await getBase64(fileUpload.file) : undefined
|
||||||
? await getBase64(fileUpload[0].file)
|
|
||||||
: undefined
|
|
||||||
|
|
||||||
if (base64Blob && fileUpload[0]?.type === 'image') {
|
if (base64Blob && fileUpload?.type === 'image') {
|
||||||
// Compress image
|
// Compress image
|
||||||
base64Blob = await compressImage(base64Blob, 512)
|
base64Blob = await compressImage(base64Blob, 512)
|
||||||
}
|
}
|
||||||
|
|
||||||
const modelRequest =
|
const modelRequest =
|
||||||
selectedModelRef?.current ?? activeThreadRef.current.assistants[0].model
|
selectedModelRef?.current ?? activeAssistantRef.current?.model
|
||||||
|
|
||||||
// Fallback support for previous broken threads
|
// Fallback support for previous broken threads
|
||||||
if (activeThreadRef.current?.assistants[0]?.model?.id === '*') {
|
if (activeAssistantRef.current?.model?.id === '*') {
|
||||||
activeThreadRef.current.assistants[0].model = {
|
activeAssistantRef.current.model = {
|
||||||
id: modelRequest.id,
|
id: modelRequest.id,
|
||||||
settings: modelRequest.settings,
|
settings: modelRequest.settings,
|
||||||
parameters: modelRequest.parameters,
|
parameters: modelRequest.parameters,
|
||||||
@ -163,9 +166,10 @@ export default function useSendChatMessage() {
|
|||||||
},
|
},
|
||||||
activeThreadRef.current,
|
activeThreadRef.current,
|
||||||
messages ?? currentMessages
|
messages ?? currentMessages
|
||||||
).addSystemMessage(activeThreadRef.current.assistants[0].instructions)
|
).addSystemMessage(activeAssistantRef.current?.instructions)
|
||||||
|
|
||||||
requestBuilder.pushMessage(prompt, base64Blob, fileUpload[0]?.type)
|
if (!isResend) {
|
||||||
|
requestBuilder.pushMessage(prompt, base64Blob, fileUpload)
|
||||||
|
|
||||||
// Build Thread Message to persist
|
// Build Thread Message to persist
|
||||||
const threadMessageBuilder = new ThreadMessageBuilder(
|
const threadMessageBuilder = new ThreadMessageBuilder(
|
||||||
@ -174,9 +178,6 @@ export default function useSendChatMessage() {
|
|||||||
|
|
||||||
const newMessage = threadMessageBuilder.build()
|
const newMessage = threadMessageBuilder.build()
|
||||||
|
|
||||||
// Push to states
|
|
||||||
addNewMessage(newMessage)
|
|
||||||
|
|
||||||
// Update thread state
|
// Update thread state
|
||||||
const updatedThread: Thread = {
|
const updatedThread: Thread = {
|
||||||
...activeThreadRef.current,
|
...activeThreadRef.current,
|
||||||
@ -189,20 +190,25 @@ export default function useSendChatMessage() {
|
|||||||
updateThread(updatedThread)
|
updateThread(updatedThread)
|
||||||
|
|
||||||
// Add message
|
// Add message
|
||||||
await extensionManager
|
const createdMessage = await extensionManager
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
?.addNewMessage(newMessage)
|
?.createMessage(newMessage)
|
||||||
|
|
||||||
|
if (!createdMessage) return
|
||||||
|
|
||||||
|
// Push to states
|
||||||
|
addNewMessage(createdMessage)
|
||||||
|
}
|
||||||
|
|
||||||
// Start Model if not started
|
// Start Model if not started
|
||||||
const modelId =
|
const modelId =
|
||||||
selectedModelRef.current?.id ??
|
selectedModelRef.current?.id ?? activeAssistantRef.current?.model.id
|
||||||
activeThreadRef.current.assistants[0].model.id
|
|
||||||
|
|
||||||
if (base64Blob) {
|
if (base64Blob) {
|
||||||
setFileUpload([])
|
setFileUpload(undefined)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (modelRef.current?.id !== modelId) {
|
if (modelRef.current?.id !== modelId && modelId) {
|
||||||
const error = await startModel(modelId).catch((error: Error) => error)
|
const error = await startModel(modelId).catch((error: Error) => error)
|
||||||
if (error) {
|
if (error) {
|
||||||
updateThreadWaiting(activeThreadRef.current.id, false)
|
updateThreadWaiting(activeThreadRef.current.id, false)
|
||||||
@ -214,9 +220,7 @@ export default function useSendChatMessage() {
|
|||||||
// Process message request with Assistants tools
|
// Process message request with Assistants tools
|
||||||
const request = await ToolManager.instance().process(
|
const request = await ToolManager.instance().process(
|
||||||
requestBuilder.build(),
|
requestBuilder.build(),
|
||||||
activeThreadRef.current.assistants?.flatMap(
|
activeAssistantRef?.current.tools ?? []
|
||||||
(assistant) => assistant.tools ?? []
|
|
||||||
) ?? []
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Request for inference
|
// Request for inference
|
||||||
|
|||||||
@ -1,12 +1,10 @@
|
|||||||
import { ExtensionTypeEnum, Thread, ConversationalExtension } from '@janhq/core'
|
import { ExtensionTypeEnum, Thread, ConversationalExtension } from '@janhq/core'
|
||||||
|
|
||||||
import { useAtomValue, useSetAtom } from 'jotai'
|
import { useSetAtom } from 'jotai'
|
||||||
|
|
||||||
import { extensionManager } from '@/extension'
|
import { extensionManager } from '@/extension'
|
||||||
import {
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
readyThreadsMessagesAtom,
|
import { setConvoMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
|
||||||
setConvoMessagesAtom,
|
|
||||||
} from '@/helpers/atoms/ChatMessage.atom'
|
|
||||||
import {
|
import {
|
||||||
setActiveThreadIdAtom,
|
setActiveThreadIdAtom,
|
||||||
setThreadModelParamsAtom,
|
setThreadModelParamsAtom,
|
||||||
@ -17,21 +15,27 @@ export default function useSetActiveThread() {
|
|||||||
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
|
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
|
||||||
const setThreadMessage = useSetAtom(setConvoMessagesAtom)
|
const setThreadMessage = useSetAtom(setConvoMessagesAtom)
|
||||||
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
|
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
|
||||||
const readyMessageThreads = useAtomValue(readyThreadsMessagesAtom)
|
const setActiveAssistant = useSetAtom(activeAssistantAtom)
|
||||||
|
|
||||||
const setActiveThread = async (thread: Thread) => {
|
const setActiveThread = async (thread: Thread) => {
|
||||||
// Load local messages only if there are no messages in the state
|
if (!thread?.id) return
|
||||||
if (!readyMessageThreads[thread?.id]) {
|
|
||||||
const messages = await getLocalThreadMessage(thread?.id)
|
|
||||||
setThreadMessage(thread?.id, messages)
|
|
||||||
}
|
|
||||||
|
|
||||||
setActiveThreadId(thread?.id)
|
setActiveThreadId(thread?.id)
|
||||||
|
|
||||||
|
try {
|
||||||
|
const assistantInfo = await getThreadAssistant(thread.id)
|
||||||
|
setActiveAssistant(assistantInfo)
|
||||||
|
// Load local messages only if there are no messages in the state
|
||||||
|
const messages = await getLocalThreadMessage(thread.id).catch(() => [])
|
||||||
const modelParams: ModelParams = {
|
const modelParams: ModelParams = {
|
||||||
...thread?.assistants[0]?.model?.parameters,
|
...assistantInfo?.model?.parameters,
|
||||||
...thread?.assistants[0]?.model?.settings,
|
...assistantInfo?.model?.settings,
|
||||||
}
|
}
|
||||||
setThreadModelParams(thread?.id, modelParams)
|
setThreadModelParams(thread?.id, modelParams)
|
||||||
|
setThreadMessage(thread.id, messages)
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return { setActiveThread }
|
return { setActiveThread }
|
||||||
@ -40,4 +44,9 @@ export default function useSetActiveThread() {
|
|||||||
const getLocalThreadMessage = async (threadId: string) =>
|
const getLocalThreadMessage = async (threadId: string) =>
|
||||||
extensionManager
|
extensionManager
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
?.getAllMessages(threadId) ?? []
|
?.listMessages(threadId) ?? []
|
||||||
|
|
||||||
|
const getThreadAssistant = async (threadId: string) =>
|
||||||
|
extensionManager
|
||||||
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
|
?.getThreadAssistant(threadId)
|
||||||
|
|||||||
@ -78,7 +78,7 @@ describe('useThreads', () => {
|
|||||||
// Mock extensionManager
|
// Mock extensionManager
|
||||||
const mockGetThreads = jest.fn().mockResolvedValue(mockThreads)
|
const mockGetThreads = jest.fn().mockResolvedValue(mockThreads)
|
||||||
;(extensionManager.get as jest.Mock).mockReturnValue({
|
;(extensionManager.get as jest.Mock).mockReturnValue({
|
||||||
getThreads: mockGetThreads,
|
listThreads: mockGetThreads,
|
||||||
})
|
})
|
||||||
|
|
||||||
const { result } = renderHook(() => useThreads())
|
const { result } = renderHook(() => useThreads())
|
||||||
@ -119,7 +119,7 @@ describe('useThreads', () => {
|
|||||||
it('should handle empty threads', async () => {
|
it('should handle empty threads', async () => {
|
||||||
// Mock empty threads
|
// Mock empty threads
|
||||||
;(extensionManager.get as jest.Mock).mockReturnValue({
|
;(extensionManager.get as jest.Mock).mockReturnValue({
|
||||||
getThreads: jest.fn().mockResolvedValue([]),
|
listThreads: jest.fn().mockResolvedValue([]),
|
||||||
})
|
})
|
||||||
|
|
||||||
const mockSetThreadStates = jest.fn()
|
const mockSetThreadStates = jest.fn()
|
||||||
|
|||||||
@ -68,6 +68,6 @@ const useThreads = () => {
|
|||||||
const getLocalThreads = async (): Promise<Thread[]> =>
|
const getLocalThreads = async (): Promise<Thread[]> =>
|
||||||
(await extensionManager
|
(await extensionManager
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
?.getThreads()) ?? []
|
?.listThreads()) ?? []
|
||||||
|
|
||||||
export default useThreads
|
export default useThreads
|
||||||
|
|||||||
@ -1,7 +1,12 @@
|
|||||||
import { renderHook, act } from '@testing-library/react'
|
import { renderHook, act } from '@testing-library/react'
|
||||||
|
import { useAtom } from 'jotai'
|
||||||
// Mock dependencies
|
// Mock dependencies
|
||||||
jest.mock('ulidx')
|
jest.mock('ulidx')
|
||||||
jest.mock('@/extension')
|
jest.mock('@/extension')
|
||||||
|
jest.mock('jotai', () => ({
|
||||||
|
...jest.requireActual('jotai'),
|
||||||
|
useAtom: jest.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
import useUpdateModelParameters from './useUpdateModelParameters'
|
import useUpdateModelParameters from './useUpdateModelParameters'
|
||||||
import { extensionManager } from '@/extension'
|
import { extensionManager } from '@/extension'
|
||||||
@ -13,7 +18,8 @@ let model: any = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let extension: any = {
|
let extension: any = {
|
||||||
saveThread: jest.fn(),
|
modifyThread: jest.fn(),
|
||||||
|
modifyThreadAssistant: jest.fn(),
|
||||||
}
|
}
|
||||||
|
|
||||||
const mockThread: any = {
|
const mockThread: any = {
|
||||||
@ -35,6 +41,7 @@ const mockThread: any = {
|
|||||||
describe('useUpdateModelParameters', () => {
|
describe('useUpdateModelParameters', () => {
|
||||||
beforeAll(() => {
|
beforeAll(() => {
|
||||||
jest.clearAllMocks()
|
jest.clearAllMocks()
|
||||||
|
jest.useFakeTimers()
|
||||||
jest.mock('./useRecommendedModel', () => ({
|
jest.mock('./useRecommendedModel', () => ({
|
||||||
useRecommendedModel: () => ({
|
useRecommendedModel: () => ({
|
||||||
recommendedModel: model,
|
recommendedModel: model,
|
||||||
@ -45,6 +52,12 @@ describe('useUpdateModelParameters', () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
it('should update model parameters and save thread when params are valid', async () => {
|
it('should update model parameters and save thread when params are valid', async () => {
|
||||||
|
;(useAtom as jest.Mock).mockReturnValue([
|
||||||
|
{
|
||||||
|
id: 'assistant-1',
|
||||||
|
},
|
||||||
|
jest.fn(),
|
||||||
|
])
|
||||||
const mockValidParameters: any = {
|
const mockValidParameters: any = {
|
||||||
params: {
|
params: {
|
||||||
// Inference
|
// Inference
|
||||||
@ -76,7 +89,8 @@ describe('useUpdateModelParameters', () => {
|
|||||||
|
|
||||||
// Spy functions
|
// Spy functions
|
||||||
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
||||||
jest.spyOn(extension, 'saveThread').mockReturnValue({})
|
jest.spyOn(extension, 'modifyThread').mockReturnValue({})
|
||||||
|
jest.spyOn(extension, 'modifyThreadAssistant').mockReturnValue({})
|
||||||
|
|
||||||
const { result } = renderHook(() => useUpdateModelParameters())
|
const { result } = renderHook(() => useUpdateModelParameters())
|
||||||
|
|
||||||
@ -84,10 +98,11 @@ describe('useUpdateModelParameters', () => {
|
|||||||
await result.current.updateModelParameter(mockThread, mockValidParameters)
|
await result.current.updateModelParameter(mockThread, mockValidParameters)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
jest.runAllTimers()
|
||||||
|
|
||||||
// Check if the model parameters are valid before persisting
|
// Check if the model parameters are valid before persisting
|
||||||
expect(extension.saveThread).toHaveBeenCalledWith({
|
expect(extension.modifyThreadAssistant).toHaveBeenCalledWith('thread-1', {
|
||||||
assistants: [
|
id: 'assistant-1',
|
||||||
{
|
|
||||||
model: {
|
model: {
|
||||||
parameters: {
|
parameters: {
|
||||||
stop: ['<eos>', '<eos2>'],
|
stop: ['<eos>', '<eos2>'],
|
||||||
@ -110,18 +125,19 @@ describe('useUpdateModelParameters', () => {
|
|||||||
llama_model_path: 'path',
|
llama_model_path: 'path',
|
||||||
mmproj: 'mmproj',
|
mmproj: 'mmproj',
|
||||||
},
|
},
|
||||||
|
id: 'model-1',
|
||||||
|
engine: 'nitro',
|
||||||
},
|
},
|
||||||
},
|
|
||||||
],
|
|
||||||
created: 0,
|
|
||||||
id: 'thread-1',
|
|
||||||
object: 'thread',
|
|
||||||
title: 'New Thread',
|
|
||||||
updated: 0,
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should not update invalid model parameters', async () => {
|
it('should not update invalid model parameters', async () => {
|
||||||
|
;(useAtom as jest.Mock).mockReturnValue([
|
||||||
|
{
|
||||||
|
id: 'assistant-1',
|
||||||
|
},
|
||||||
|
jest.fn(),
|
||||||
|
])
|
||||||
const mockInvalidParameters: any = {
|
const mockInvalidParameters: any = {
|
||||||
params: {
|
params: {
|
||||||
// Inference
|
// Inference
|
||||||
@ -153,7 +169,8 @@ describe('useUpdateModelParameters', () => {
|
|||||||
|
|
||||||
// Spy functions
|
// Spy functions
|
||||||
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
||||||
jest.spyOn(extension, 'saveThread').mockReturnValue({})
|
jest.spyOn(extension, 'modifyThread').mockReturnValue({})
|
||||||
|
jest.spyOn(extension, 'modifyThreadAssistant').mockReturnValue({})
|
||||||
|
|
||||||
const { result } = renderHook(() => useUpdateModelParameters())
|
const { result } = renderHook(() => useUpdateModelParameters())
|
||||||
|
|
||||||
@ -164,14 +181,17 @@ describe('useUpdateModelParameters', () => {
|
|||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
jest.runAllTimers()
|
||||||
|
|
||||||
// Check if the model parameters are valid before persisting
|
// Check if the model parameters are valid before persisting
|
||||||
expect(extension.saveThread).toHaveBeenCalledWith({
|
expect(extension.modifyThreadAssistant).toHaveBeenCalledWith('thread-1', {
|
||||||
assistants: [
|
id: 'assistant-1',
|
||||||
{
|
|
||||||
model: {
|
model: {
|
||||||
|
engine: 'nitro',
|
||||||
|
id: 'model-1',
|
||||||
parameters: {
|
parameters: {
|
||||||
max_tokens: 1000,
|
|
||||||
token_limit: 1000,
|
token_limit: 1000,
|
||||||
|
max_tokens: 1000,
|
||||||
},
|
},
|
||||||
settings: {
|
settings: {
|
||||||
cpu_threads: 4,
|
cpu_threads: 4,
|
||||||
@ -183,17 +203,16 @@ describe('useUpdateModelParameters', () => {
|
|||||||
ngl: 12,
|
ngl: 12,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
|
||||||
],
|
|
||||||
created: 0,
|
|
||||||
id: 'thread-1',
|
|
||||||
object: 'thread',
|
|
||||||
title: 'New Thread',
|
|
||||||
updated: 0,
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should update valid model parameters only', async () => {
|
it('should update valid model parameters only', async () => {
|
||||||
|
;(useAtom as jest.Mock).mockReturnValue([
|
||||||
|
{
|
||||||
|
id: 'assistant-1',
|
||||||
|
},
|
||||||
|
jest.fn(),
|
||||||
|
])
|
||||||
const mockInvalidParameters: any = {
|
const mockInvalidParameters: any = {
|
||||||
params: {
|
params: {
|
||||||
// Inference
|
// Inference
|
||||||
@ -225,8 +244,8 @@ describe('useUpdateModelParameters', () => {
|
|||||||
|
|
||||||
// Spy functions
|
// Spy functions
|
||||||
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
||||||
jest.spyOn(extension, 'saveThread').mockReturnValue({})
|
jest.spyOn(extension, 'modifyThread').mockReturnValue({})
|
||||||
|
jest.spyOn(extension, 'modifyThreadAssistant').mockReturnValue({})
|
||||||
const { result } = renderHook(() => useUpdateModelParameters())
|
const { result } = renderHook(() => useUpdateModelParameters())
|
||||||
|
|
||||||
await act(async () => {
|
await act(async () => {
|
||||||
@ -235,12 +254,14 @@ describe('useUpdateModelParameters', () => {
|
|||||||
mockInvalidParameters
|
mockInvalidParameters
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
jest.runAllTimers()
|
||||||
|
|
||||||
// Check if the model parameters are valid before persisting
|
// Check if the model parameters are valid before persisting
|
||||||
expect(extension.saveThread).toHaveBeenCalledWith({
|
expect(extension.modifyThreadAssistant).toHaveBeenCalledWith('thread-1', {
|
||||||
assistants: [
|
id: 'assistant-1',
|
||||||
{
|
|
||||||
model: {
|
model: {
|
||||||
|
engine: 'nitro',
|
||||||
|
id: 'model-1',
|
||||||
parameters: {
|
parameters: {
|
||||||
stop: ['<eos>'],
|
stop: ['<eos>'],
|
||||||
top_k: 0.7,
|
top_k: 0.7,
|
||||||
@ -260,55 +281,6 @@ describe('useUpdateModelParameters', () => {
|
|||||||
mmproj: 'mmproj',
|
mmproj: 'mmproj',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
|
||||||
],
|
|
||||||
created: 0,
|
|
||||||
id: 'thread-1',
|
|
||||||
object: 'thread',
|
|
||||||
title: 'New Thread',
|
|
||||||
updated: 0,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should handle missing modelId and engine gracefully', async () => {
|
|
||||||
const mockParametersWithoutModelIdAndEngine: any = {
|
|
||||||
params: {
|
|
||||||
stop: ['<eos>', '<eos2>'],
|
|
||||||
temperature: 0.5,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Spy functions
|
|
||||||
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
|
||||||
jest.spyOn(extension, 'saveThread').mockReturnValue({})
|
|
||||||
|
|
||||||
const { result } = renderHook(() => useUpdateModelParameters())
|
|
||||||
|
|
||||||
await act(async () => {
|
|
||||||
await result.current.updateModelParameter(
|
|
||||||
mockThread,
|
|
||||||
mockParametersWithoutModelIdAndEngine
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Check if the model parameters are valid before persisting
|
|
||||||
expect(extension.saveThread).toHaveBeenCalledWith({
|
|
||||||
assistants: [
|
|
||||||
{
|
|
||||||
model: {
|
|
||||||
parameters: {
|
|
||||||
stop: ['<eos>', '<eos2>'],
|
|
||||||
temperature: 0.5,
|
|
||||||
},
|
|
||||||
settings: {},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
created: 0,
|
|
||||||
id: 'thread-1',
|
|
||||||
object: 'thread',
|
|
||||||
title: 'New Thread',
|
|
||||||
updated: 0,
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@ -12,7 +12,10 @@ import {
|
|||||||
|
|
||||||
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
|
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||||
|
|
||||||
|
import { useDebouncedCallback } from 'use-debounce'
|
||||||
|
|
||||||
import { extensionManager } from '@/extension'
|
import { extensionManager } from '@/extension'
|
||||||
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||||
import {
|
import {
|
||||||
getActiveThreadModelParamsAtom,
|
getActiveThreadModelParamsAtom,
|
||||||
@ -29,11 +32,28 @@ export type UpdateModelParameter = {
|
|||||||
|
|
||||||
export default function useUpdateModelParameters() {
|
export default function useUpdateModelParameters() {
|
||||||
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
|
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
|
||||||
|
const [activeAssistant, setActiveAssistant] = useAtom(activeAssistantAtom)
|
||||||
const [selectedModel] = useAtom(selectedModelAtom)
|
const [selectedModel] = useAtom(selectedModelAtom)
|
||||||
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
|
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
|
||||||
|
|
||||||
|
const updateAssistantExtension = (
|
||||||
|
threadId: string,
|
||||||
|
assistant: ThreadAssistantInfo
|
||||||
|
) => {
|
||||||
|
return extensionManager
|
||||||
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
|
?.modifyThreadAssistant(threadId, assistant)
|
||||||
|
}
|
||||||
|
|
||||||
|
const updateAssistantCallback = useDebouncedCallback(
|
||||||
|
updateAssistantExtension,
|
||||||
|
300
|
||||||
|
)
|
||||||
|
|
||||||
const updateModelParameter = useCallback(
|
const updateModelParameter = useCallback(
|
||||||
async (thread: Thread, settings: UpdateModelParameter) => {
|
async (thread: Thread, settings: UpdateModelParameter) => {
|
||||||
|
if (!activeAssistant) return
|
||||||
|
|
||||||
const toUpdateSettings = processStopWords(settings.params ?? {})
|
const toUpdateSettings = processStopWords(settings.params ?? {})
|
||||||
const updatedModelParams = settings.modelId
|
const updatedModelParams = settings.modelId
|
||||||
? toUpdateSettings
|
? toUpdateSettings
|
||||||
@ -48,30 +68,34 @@ export default function useUpdateModelParameters() {
|
|||||||
setThreadModelParams(thread.id, updatedModelParams)
|
setThreadModelParams(thread.id, updatedModelParams)
|
||||||
const runtimeParams = extractInferenceParams(updatedModelParams)
|
const runtimeParams = extractInferenceParams(updatedModelParams)
|
||||||
const settingParams = extractModelLoadParams(updatedModelParams)
|
const settingParams = extractModelLoadParams(updatedModelParams)
|
||||||
|
const assistantInfo = {
|
||||||
const assistants = thread.assistants.map(
|
...activeAssistant,
|
||||||
(assistant: ThreadAssistantInfo) => {
|
model: {
|
||||||
assistant.model.parameters = runtimeParams
|
...activeAssistant?.model,
|
||||||
assistant.model.settings = settingParams
|
parameters: runtimeParams,
|
||||||
if (selectedModel) {
|
settings: settingParams,
|
||||||
assistant.model.id = settings.modelId ?? selectedModel?.id
|
id: settings.modelId ?? selectedModel?.id ?? activeAssistant.model.id,
|
||||||
assistant.model.engine = settings.engine ?? selectedModel?.engine
|
engine:
|
||||||
}
|
settings.engine ??
|
||||||
return assistant
|
selectedModel?.engine ??
|
||||||
}
|
activeAssistant.model.engine,
|
||||||
)
|
|
||||||
|
|
||||||
// update thread
|
|
||||||
const updatedThread: Thread = {
|
|
||||||
...thread,
|
|
||||||
assistants,
|
|
||||||
}
|
|
||||||
|
|
||||||
await extensionManager
|
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
|
||||||
?.saveThread(updatedThread)
|
|
||||||
},
|
},
|
||||||
[activeModelParams, selectedModel, setThreadModelParams]
|
}
|
||||||
|
setActiveAssistant(assistantInfo)
|
||||||
|
|
||||||
|
updateAssistantCallback(thread.id, assistantInfo)
|
||||||
|
},
|
||||||
|
[
|
||||||
|
activeAssistant,
|
||||||
|
selectedModel?.parameters,
|
||||||
|
selectedModel?.settings,
|
||||||
|
selectedModel?.id,
|
||||||
|
selectedModel?.engine,
|
||||||
|
activeModelParams,
|
||||||
|
setThreadModelParams,
|
||||||
|
setActiveAssistant,
|
||||||
|
updateAssistantCallback,
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
const processStopWords = (params: ModelParams): ModelParams => {
|
const processStopWords = (params: ModelParams): ModelParams => {
|
||||||
|
|||||||
@ -37,5 +37,5 @@ const config = {
|
|||||||
// module.exports = createJestConfig(config)
|
// module.exports = createJestConfig(config)
|
||||||
module.exports = async () => ({
|
module.exports = async () => ({
|
||||||
...(await createJestConfig(config)()),
|
...(await createJestConfig(config)()),
|
||||||
transformIgnorePatterns: ['/node_modules/(?!(layerr)/)'],
|
transformIgnorePatterns: ['/node_modules/(?!(layerr|nanoid|@uppy|preact)/)'],
|
||||||
})
|
})
|
||||||
|
|||||||
@ -35,7 +35,7 @@ const nextConfig = {
|
|||||||
POSTHOG_HOST: JSON.stringify(process.env.POSTHOG_HOST),
|
POSTHOG_HOST: JSON.stringify(process.env.POSTHOG_HOST),
|
||||||
ANALYTICS_HOST: JSON.stringify(process.env.ANALYTICS_HOST),
|
ANALYTICS_HOST: JSON.stringify(process.env.ANALYTICS_HOST),
|
||||||
API_BASE_URL: JSON.stringify(
|
API_BASE_URL: JSON.stringify(
|
||||||
process.env.API_BASE_URL ?? 'http://localhost:1337'
|
process.env.API_BASE_URL ?? 'http://127.0.0.1:39291'
|
||||||
),
|
),
|
||||||
isMac: process.platform === 'darwin',
|
isMac: process.platform === 'darwin',
|
||||||
isWindows: process.platform === 'win32',
|
isWindows: process.platform === 'win32',
|
||||||
|
|||||||
@ -17,6 +17,9 @@
|
|||||||
"@janhq/core": "link:./core",
|
"@janhq/core": "link:./core",
|
||||||
"@janhq/joi": "link:./joi",
|
"@janhq/joi": "link:./joi",
|
||||||
"@tanstack/react-virtual": "^3.10.9",
|
"@tanstack/react-virtual": "^3.10.9",
|
||||||
|
"@uppy/core": "^4.3.0",
|
||||||
|
"@uppy/react": "^4.0.4",
|
||||||
|
"@uppy/xhr-upload": "^4.2.3",
|
||||||
"autoprefixer": "10.4.16",
|
"autoprefixer": "10.4.16",
|
||||||
"class-variance-authority": "^0.7.0",
|
"class-variance-authority": "^0.7.0",
|
||||||
"framer-motion": "^10.16.4",
|
"framer-motion": "^10.16.4",
|
||||||
|
|||||||
@ -7,6 +7,8 @@ import { useAtomValue, useSetAtom } from 'jotai'
|
|||||||
import { useActiveModel } from '@/hooks/useActiveModel'
|
import { useActiveModel } from '@/hooks/useActiveModel'
|
||||||
import { useCreateNewThread } from '@/hooks/useCreateNewThread'
|
import { useCreateNewThread } from '@/hooks/useCreateNewThread'
|
||||||
import AssistantSetting from './index'
|
import AssistantSetting from './index'
|
||||||
|
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||||
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
|
|
||||||
jest.mock('jotai', () => {
|
jest.mock('jotai', () => {
|
||||||
const originalModule = jest.requireActual('jotai')
|
const originalModule = jest.requireActual('jotai')
|
||||||
@ -68,6 +70,7 @@ describe('AssistantSetting Component', () => {
|
|||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
jest.clearAllMocks()
|
||||||
|
jest.useFakeTimers()
|
||||||
})
|
})
|
||||||
|
|
||||||
test('renders AssistantSetting component with proper data', async () => {
|
test('renders AssistantSetting component with proper data', async () => {
|
||||||
@ -75,7 +78,14 @@ describe('AssistantSetting Component', () => {
|
|||||||
;(useSetAtom as jest.Mock).mockImplementationOnce(
|
;(useSetAtom as jest.Mock).mockImplementationOnce(
|
||||||
() => setEngineParamsUpdate
|
() => setEngineParamsUpdate
|
||||||
)
|
)
|
||||||
;(useAtomValue as jest.Mock).mockImplementationOnce(() => mockActiveThread)
|
;(useAtomValue as jest.Mock).mockImplementation((atom) => {
|
||||||
|
switch (atom) {
|
||||||
|
case activeThreadAtom:
|
||||||
|
return mockActiveThread
|
||||||
|
case activeAssistantAtom:
|
||||||
|
return {}
|
||||||
|
}
|
||||||
|
})
|
||||||
const updateThreadMetadata = jest.fn()
|
const updateThreadMetadata = jest.fn()
|
||||||
;(useActiveModel as jest.Mock).mockReturnValueOnce({ stopModel: jest.fn() })
|
;(useActiveModel as jest.Mock).mockReturnValueOnce({ stopModel: jest.fn() })
|
||||||
;(useCreateNewThread as jest.Mock).mockReturnValueOnce({
|
;(useCreateNewThread as jest.Mock).mockReturnValueOnce({
|
||||||
@ -98,7 +108,14 @@ describe('AssistantSetting Component', () => {
|
|||||||
const setEngineParamsUpdate = jest.fn()
|
const setEngineParamsUpdate = jest.fn()
|
||||||
const updateThreadMetadata = jest.fn()
|
const updateThreadMetadata = jest.fn()
|
||||||
const stopModel = jest.fn()
|
const stopModel = jest.fn()
|
||||||
;(useAtomValue as jest.Mock).mockImplementationOnce(() => mockActiveThread)
|
;(useAtomValue as jest.Mock).mockImplementation((atom) => {
|
||||||
|
switch (atom) {
|
||||||
|
case activeThreadAtom:
|
||||||
|
return mockActiveThread
|
||||||
|
case activeAssistantAtom:
|
||||||
|
return {}
|
||||||
|
}
|
||||||
|
})
|
||||||
;(useSetAtom as jest.Mock).mockImplementation(() => setEngineParamsUpdate)
|
;(useSetAtom as jest.Mock).mockImplementation(() => setEngineParamsUpdate)
|
||||||
;(useActiveModel as jest.Mock).mockReturnValueOnce({ stopModel })
|
;(useActiveModel as jest.Mock).mockReturnValueOnce({ stopModel })
|
||||||
;(useCreateNewThread as jest.Mock).mockReturnValueOnce({
|
;(useCreateNewThread as jest.Mock).mockReturnValueOnce({
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import { useCreateNewThread } from '@/hooks/useCreateNewThread'
|
|||||||
|
|
||||||
import SettingComponentBuilder from '../../../../containers/ModelSetting/SettingComponent'
|
import SettingComponentBuilder from '../../../../containers/ModelSetting/SettingComponent'
|
||||||
|
|
||||||
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
import {
|
import {
|
||||||
activeThreadAtom,
|
activeThreadAtom,
|
||||||
engineParamsUpdateAtom,
|
engineParamsUpdateAtom,
|
||||||
@ -19,13 +20,14 @@ type Props = {
|
|||||||
|
|
||||||
const AssistantSetting: React.FC<Props> = ({ componentData }) => {
|
const AssistantSetting: React.FC<Props> = ({ componentData }) => {
|
||||||
const activeThread = useAtomValue(activeThreadAtom)
|
const activeThread = useAtomValue(activeThreadAtom)
|
||||||
|
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||||
const { updateThreadMetadata } = useCreateNewThread()
|
const { updateThreadMetadata } = useCreateNewThread()
|
||||||
const { stopModel } = useActiveModel()
|
const { stopModel } = useActiveModel()
|
||||||
const setEngineParamsUpdate = useSetAtom(engineParamsUpdateAtom)
|
const setEngineParamsUpdate = useSetAtom(engineParamsUpdateAtom)
|
||||||
|
|
||||||
const onValueChanged = useCallback(
|
const onValueChanged = useCallback(
|
||||||
(key: string, value: string | number | boolean | string[]) => {
|
(key: string, value: string | number | boolean | string[]) => {
|
||||||
if (!activeThread) return
|
if (!activeThread || !activeAssistant) return
|
||||||
const shouldReloadModel =
|
const shouldReloadModel =
|
||||||
componentData.find((x) => x.key === key)?.requireModelReload ?? false
|
componentData.find((x) => x.key === key)?.requireModelReload ?? false
|
||||||
if (shouldReloadModel) {
|
if (shouldReloadModel) {
|
||||||
@ -34,40 +36,40 @@ const AssistantSetting: React.FC<Props> = ({ componentData }) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
activeThread.assistants[0].tools &&
|
activeAssistant?.tools &&
|
||||||
(key === 'chunk_overlap' || key === 'chunk_size')
|
(key === 'chunk_overlap' || key === 'chunk_size')
|
||||||
) {
|
) {
|
||||||
if (
|
if (
|
||||||
activeThread.assistants[0].tools[0]?.settings?.chunk_size <
|
activeAssistant.tools[0]?.settings?.chunk_size <
|
||||||
activeThread.assistants[0].tools[0]?.settings?.chunk_overlap
|
activeAssistant.tools[0]?.settings?.chunk_overlap
|
||||||
) {
|
) {
|
||||||
activeThread.assistants[0].tools[0].settings.chunk_overlap =
|
activeAssistant.tools[0].settings.chunk_overlap =
|
||||||
activeThread.assistants[0].tools[0].settings.chunk_size
|
activeAssistant.tools[0].settings.chunk_size
|
||||||
}
|
}
|
||||||
if (
|
if (
|
||||||
key === 'chunk_size' &&
|
key === 'chunk_size' &&
|
||||||
value < activeThread.assistants[0].tools[0].settings?.chunk_overlap
|
value < activeAssistant.tools[0].settings?.chunk_overlap
|
||||||
) {
|
) {
|
||||||
activeThread.assistants[0].tools[0].settings.chunk_overlap = value
|
activeAssistant.tools[0].settings.chunk_overlap = value
|
||||||
} else if (
|
} else if (
|
||||||
key === 'chunk_overlap' &&
|
key === 'chunk_overlap' &&
|
||||||
value > activeThread.assistants[0].tools[0].settings?.chunk_size
|
value > activeAssistant.tools[0].settings?.chunk_size
|
||||||
) {
|
) {
|
||||||
activeThread.assistants[0].tools[0].settings.chunk_size = value
|
activeAssistant.tools[0].settings.chunk_size = value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
updateThreadMetadata({
|
updateThreadMetadata({
|
||||||
...activeThread,
|
...activeThread,
|
||||||
assistants: [
|
assistants: [
|
||||||
{
|
{
|
||||||
...activeThread.assistants[0],
|
...activeAssistant,
|
||||||
tools: [
|
tools: [
|
||||||
{
|
{
|
||||||
type: 'retrieval',
|
type: 'retrieval',
|
||||||
enabled: true,
|
enabled: true,
|
||||||
settings: {
|
settings: {
|
||||||
...(activeThread.assistants[0].tools &&
|
...(activeAssistant.tools &&
|
||||||
activeThread.assistants[0].tools[0]?.settings),
|
activeAssistant.tools[0]?.settings),
|
||||||
[key]: value,
|
[key]: value,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -77,6 +79,7 @@ const AssistantSetting: React.FC<Props> = ({ componentData }) => {
|
|||||||
})
|
})
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
|
activeAssistant,
|
||||||
activeThread,
|
activeThread,
|
||||||
componentData,
|
componentData,
|
||||||
setEngineParamsUpdate,
|
setEngineParamsUpdate,
|
||||||
|
|||||||
@ -24,6 +24,7 @@ import { useActiveModel } from '@/hooks/useActiveModel'
|
|||||||
|
|
||||||
import useSendChatMessage from '@/hooks/useSendChatMessage'
|
import useSendChatMessage from '@/hooks/useSendChatMessage'
|
||||||
|
|
||||||
|
import { uploader } from '@/utils/file'
|
||||||
import { isLocalEngine } from '@/utils/modelEngine'
|
import { isLocalEngine } from '@/utils/modelEngine'
|
||||||
|
|
||||||
import FileUploadPreview from '../FileUploadPreview'
|
import FileUploadPreview from '../FileUploadPreview'
|
||||||
@ -33,6 +34,7 @@ import RichTextEditor from './RichTextEditor'
|
|||||||
|
|
||||||
import { showRightPanelAtom } from '@/helpers/atoms/App.atom'
|
import { showRightPanelAtom } from '@/helpers/atoms/App.atom'
|
||||||
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||||
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
|
import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
|
||||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||||
import { spellCheckAtom } from '@/helpers/atoms/Setting.atom'
|
import { spellCheckAtom } from '@/helpers/atoms/Setting.atom'
|
||||||
@ -67,8 +69,10 @@ const ChatInput = () => {
|
|||||||
const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom)
|
const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom)
|
||||||
const isGeneratingResponse = useAtomValue(isGeneratingResponseAtom)
|
const isGeneratingResponse = useAtomValue(isGeneratingResponseAtom)
|
||||||
const threadStates = useAtomValue(threadStatesAtom)
|
const threadStates = useAtomValue(threadStatesAtom)
|
||||||
|
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||||
const { stopInference } = useActiveModel()
|
const { stopInference } = useActiveModel()
|
||||||
|
|
||||||
|
const upload = uploader()
|
||||||
const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom(
|
const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom(
|
||||||
activeTabThreadRightPanelAtom
|
activeTabThreadRightPanelAtom
|
||||||
)
|
)
|
||||||
@ -102,18 +106,26 @@ const ChatInput = () => {
|
|||||||
const handleFileChange = (event: React.ChangeEvent<HTMLInputElement>) => {
|
const handleFileChange = (event: React.ChangeEvent<HTMLInputElement>) => {
|
||||||
const file = event.target.files?.[0]
|
const file = event.target.files?.[0]
|
||||||
if (!file) return
|
if (!file) return
|
||||||
setFileUpload([{ file: file, type: 'pdf' }])
|
upload.addFile(file)
|
||||||
|
upload.upload().then((data) => {
|
||||||
|
setFileUpload({
|
||||||
|
file: file,
|
||||||
|
type: 'pdf',
|
||||||
|
id: data?.successful?.[0]?.response?.body?.id,
|
||||||
|
name: data?.successful?.[0]?.response?.body?.filename,
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleImageChange = (event: React.ChangeEvent<HTMLInputElement>) => {
|
const handleImageChange = (event: React.ChangeEvent<HTMLInputElement>) => {
|
||||||
const file = event.target.files?.[0]
|
const file = event.target.files?.[0]
|
||||||
if (!file) return
|
if (!file) return
|
||||||
setFileUpload([{ file: file, type: 'image' }])
|
setFileUpload({ file: file, type: 'image' })
|
||||||
}
|
}
|
||||||
|
|
||||||
const renderPreview = (fileUpload: any) => {
|
const renderPreview = (fileUpload: any) => {
|
||||||
if (fileUpload.length > 0) {
|
if (fileUpload) {
|
||||||
if (fileUpload[0].type === 'image') {
|
if (fileUpload.type === 'image') {
|
||||||
return <ImageUploadPreview file={fileUpload[0].file} />
|
return <ImageUploadPreview file={fileUpload[0].file} />
|
||||||
} else {
|
} else {
|
||||||
return <FileUploadPreview />
|
return <FileUploadPreview />
|
||||||
@ -130,7 +142,7 @@ const ChatInput = () => {
|
|||||||
'relative mb-1 max-h-[400px] resize-none rounded-lg border border-[hsla(var(--app-border))] p-3 pr-20',
|
'relative mb-1 max-h-[400px] resize-none rounded-lg border border-[hsla(var(--app-border))] p-3 pr-20',
|
||||||
'focus-within:outline-none focus-visible:outline-0 focus-visible:ring-1 focus-visible:ring-[hsla(var(--primary-bg))] focus-visible:ring-offset-0',
|
'focus-within:outline-none focus-visible:outline-0 focus-visible:ring-1 focus-visible:ring-[hsla(var(--primary-bg))] focus-visible:ring-offset-0',
|
||||||
'overflow-y-auto',
|
'overflow-y-auto',
|
||||||
fileUpload.length && 'rounded-t-none',
|
fileUpload && 'rounded-t-none',
|
||||||
experimentalFeature && 'pl-10',
|
experimentalFeature && 'pl-10',
|
||||||
activeSettingInputBox && 'pb-14 pr-16'
|
activeSettingInputBox && 'pb-14 pr-16'
|
||||||
)}
|
)}
|
||||||
@ -152,10 +164,10 @@ const ChatInput = () => {
|
|||||||
className="absolute left-3 top-2.5"
|
className="absolute left-3 top-2.5"
|
||||||
onClick={(e) => {
|
onClick={(e) => {
|
||||||
if (
|
if (
|
||||||
fileUpload.length > 0 ||
|
!!fileUpload ||
|
||||||
(activeThread?.assistants[0].tools &&
|
(activeAssistant?.tools &&
|
||||||
!activeThread?.assistants[0].tools[0]?.enabled &&
|
!activeAssistant?.tools[0]?.enabled &&
|
||||||
!activeThread?.assistants[0].model.settings?.vision_model)
|
!activeAssistant?.model.settings?.vision_model)
|
||||||
) {
|
) {
|
||||||
e.stopPropagation()
|
e.stopPropagation()
|
||||||
} else {
|
} else {
|
||||||
@ -171,26 +183,24 @@ const ChatInput = () => {
|
|||||||
}
|
}
|
||||||
disabled={
|
disabled={
|
||||||
isModelSupportRagAndTools &&
|
isModelSupportRagAndTools &&
|
||||||
activeThread?.assistants[0].tools &&
|
activeAssistant?.tools &&
|
||||||
activeThread?.assistants[0].tools[0]?.enabled
|
activeAssistant?.tools[0]?.enabled
|
||||||
}
|
}
|
||||||
content={
|
content={
|
||||||
<>
|
<>
|
||||||
{fileUpload.length > 0 ||
|
{!!fileUpload ||
|
||||||
(activeThread?.assistants[0].tools &&
|
(activeAssistant?.tools &&
|
||||||
!activeThread?.assistants[0].tools[0]?.enabled &&
|
!activeAssistant?.tools[0]?.enabled &&
|
||||||
!activeThread?.assistants[0].model.settings
|
!activeAssistant?.model.settings?.vision_model && (
|
||||||
?.vision_model && (
|
|
||||||
<>
|
<>
|
||||||
{fileUpload.length !== 0 && (
|
{!!fileUpload && (
|
||||||
<span>
|
<span>
|
||||||
Currently, we only support 1 attachment at the same
|
Currently, we only support 1 attachment at the same
|
||||||
time.
|
time.
|
||||||
</span>
|
</span>
|
||||||
)}
|
)}
|
||||||
{activeThread?.assistants[0].tools &&
|
{activeAssistant?.tools &&
|
||||||
activeThread?.assistants[0].tools[0]?.enabled ===
|
activeAssistant?.tools[0]?.enabled === false &&
|
||||||
false &&
|
|
||||||
isModelSupportRagAndTools && (
|
isModelSupportRagAndTools && (
|
||||||
<span>
|
<span>
|
||||||
Turn on Retrieval in Tools settings to use this
|
Turn on Retrieval in Tools settings to use this
|
||||||
@ -221,14 +231,12 @@ const ChatInput = () => {
|
|||||||
<li
|
<li
|
||||||
className={twMerge(
|
className={twMerge(
|
||||||
'text-[hsla(var(--text-secondary)] hover:bg-secondary flex w-full items-center space-x-2 px-4 py-2 hover:bg-[hsla(var(--dropdown-menu-hover-bg))]',
|
'text-[hsla(var(--text-secondary)] hover:bg-secondary flex w-full items-center space-x-2 px-4 py-2 hover:bg-[hsla(var(--dropdown-menu-hover-bg))]',
|
||||||
activeThread?.assistants[0].model.settings?.vision_model
|
activeAssistant?.model.settings?.vision_model
|
||||||
? 'cursor-pointer'
|
? 'cursor-pointer'
|
||||||
: 'cursor-not-allowed opacity-50'
|
: 'cursor-not-allowed opacity-50'
|
||||||
)}
|
)}
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
if (
|
if (activeAssistant?.model.settings?.vision_model) {
|
||||||
activeThread?.assistants[0].model.settings?.vision_model
|
|
||||||
) {
|
|
||||||
imageInputRef.current?.click()
|
imageInputRef.current?.click()
|
||||||
setShowAttacmentMenus(false)
|
setShowAttacmentMenus(false)
|
||||||
}
|
}
|
||||||
@ -239,9 +247,7 @@ const ChatInput = () => {
|
|||||||
</li>
|
</li>
|
||||||
}
|
}
|
||||||
content="This feature only supports multimodal models."
|
content="This feature only supports multimodal models."
|
||||||
disabled={
|
disabled={activeAssistant?.model.settings?.vision_model}
|
||||||
activeThread?.assistants[0].model.settings?.vision_model
|
|
||||||
}
|
|
||||||
/>
|
/>
|
||||||
<Tooltip
|
<Tooltip
|
||||||
side="bottom"
|
side="bottom"
|
||||||
@ -261,8 +267,8 @@ const ChatInput = () => {
|
|||||||
</li>
|
</li>
|
||||||
}
|
}
|
||||||
content={
|
content={
|
||||||
(!activeThread?.assistants[0].tools ||
|
(!activeAssistant?.tools ||
|
||||||
!activeThread?.assistants[0].tools[0]?.enabled) && (
|
!activeAssistant?.tools[0]?.enabled) && (
|
||||||
<span>
|
<span>
|
||||||
Turn on Retrieval in Assistant Settings to use this
|
Turn on Retrieval in Assistant Settings to use this
|
||||||
feature.
|
feature.
|
||||||
|
|||||||
@ -72,7 +72,8 @@ const EditChatInput: React.FC<Props> = ({ message }) => {
|
|||||||
}, [editPrompt])
|
}, [editPrompt])
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setEditPrompt(message.content[0]?.text?.value)
|
if (message.content?.[0]?.text?.value)
|
||||||
|
setEditPrompt(message.content[0].text.value)
|
||||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
@ -80,19 +81,17 @@ const EditChatInput: React.FC<Props> = ({ message }) => {
|
|||||||
setEditMessage('')
|
setEditMessage('')
|
||||||
const messageIdx = messages.findIndex((msg) => msg.id === message.id)
|
const messageIdx = messages.findIndex((msg) => msg.id === message.id)
|
||||||
const newMessages = messages.slice(0, messageIdx)
|
const newMessages = messages.slice(0, messageIdx)
|
||||||
if (activeThread) {
|
const toDeleteMessages = messages.slice(messageIdx)
|
||||||
setMessages(activeThread.id, newMessages)
|
const threadId = messages[0].thread_id
|
||||||
await extensionManager
|
await Promise.all(
|
||||||
|
toDeleteMessages.map(async (message) =>
|
||||||
|
extensionManager
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
?.writeMessages(
|
?.deleteMessage(message.thread_id, message.id)
|
||||||
activeThread.id,
|
|
||||||
// Remove all of the messages below this
|
|
||||||
newMessages
|
|
||||||
)
|
)
|
||||||
.then(() => {
|
)
|
||||||
sendChatMessage(editPrompt, newMessages)
|
setMessages(threadId, newMessages)
|
||||||
})
|
sendChatMessage(editPrompt, false, newMessages)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const onKeyDown = async (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
|
const onKeyDown = async (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
|
||||||
|
|||||||
@ -15,21 +15,22 @@ const FileUploadPreview = () => {
|
|||||||
const setCurrentPrompt = useSetAtom(currentPromptAtom)
|
const setCurrentPrompt = useSetAtom(currentPromptAtom)
|
||||||
|
|
||||||
const onDeleteClick = () => {
|
const onDeleteClick = () => {
|
||||||
setFileUpload([])
|
setFileUpload(undefined)
|
||||||
setCurrentPrompt('')
|
setCurrentPrompt('')
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col rounded-t-lg border border-b-0 border-[hsla(var(--app-border))] p-4">
|
<div className="flex flex-col rounded-t-lg border border-b-0 border-[hsla(var(--app-border))] p-4">
|
||||||
|
{!!fileUpload && (
|
||||||
<div className="bg-secondary relative inline-flex w-60 space-x-3 rounded-lg p-4">
|
<div className="bg-secondary relative inline-flex w-60 space-x-3 rounded-lg p-4">
|
||||||
<Icon type={fileUpload[0].type} />
|
<Icon type={fileUpload?.type} />
|
||||||
|
|
||||||
<div className="w-full">
|
<div className="w-full">
|
||||||
<h6 className="line-clamp-1 w-3/4 truncate font-medium">
|
<h6 className="line-clamp-1 w-3/4 truncate font-medium">
|
||||||
{fileUpload[0].file.name.replaceAll(/[-._]/g, ' ')}
|
{fileUpload?.file.name.replaceAll(/[-._]/g, ' ')}
|
||||||
</h6>
|
</h6>
|
||||||
<p className="text-[hsla(var(--text-secondary)]">
|
<p className="text-[hsla(var(--text-secondary)]">
|
||||||
{toGibibytes(fileUpload[0].file.size)}
|
{toGibibytes(fileUpload?.file.size)}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@ -40,6 +41,7 @@ const FileUploadPreview = () => {
|
|||||||
<XIcon size={14} className="text-background" />
|
<XIcon size={14} className="text-background" />
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -29,7 +29,7 @@ const ImageUploadPreview: React.FC<Props> = ({ file }) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const onDeleteClick = () => {
|
const onDeleteClick = () => {
|
||||||
setFileUpload([])
|
setFileUpload(undefined)
|
||||||
setCurrentPrompt('')
|
setCurrentPrompt('')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -10,15 +10,15 @@ import { MainViewState } from '@/constants/screens'
|
|||||||
import { loadModelErrorAtom } from '@/hooks/useActiveModel'
|
import { loadModelErrorAtom } from '@/hooks/useActiveModel'
|
||||||
|
|
||||||
import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
|
import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
|
||||||
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom'
|
import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom'
|
||||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
|
||||||
|
|
||||||
const LoadModelError = () => {
|
const LoadModelError = () => {
|
||||||
const setModalTroubleShooting = useSetAtom(modalTroubleShootingAtom)
|
const setModalTroubleShooting = useSetAtom(modalTroubleShootingAtom)
|
||||||
const loadModelError = useAtomValue(loadModelErrorAtom)
|
const loadModelError = useAtomValue(loadModelErrorAtom)
|
||||||
const setMainState = useSetAtom(mainViewStateAtom)
|
const setMainState = useSetAtom(mainViewStateAtom)
|
||||||
const setSelectedSettingScreen = useSetAtom(selectedSettingAtom)
|
const setSelectedSettingScreen = useSetAtom(selectedSettingAtom)
|
||||||
const activeThread = useAtomValue(activeThreadAtom)
|
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||||
|
|
||||||
const ErrorMessage = () => {
|
const ErrorMessage = () => {
|
||||||
if (
|
if (
|
||||||
@ -33,9 +33,9 @@ const LoadModelError = () => {
|
|||||||
className="cursor-pointer font-medium text-[hsla(var(--app-link))]"
|
className="cursor-pointer font-medium text-[hsla(var(--app-link))]"
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
setMainState(MainViewState.Settings)
|
setMainState(MainViewState.Settings)
|
||||||
if (activeThread?.assistants[0]?.model.engine) {
|
if (activeAssistant?.model.engine) {
|
||||||
const engine = EngineManager.instance().get(
|
const engine = EngineManager.instance().get(
|
||||||
activeThread.assistants[0].model.engine
|
activeAssistant.model.engine
|
||||||
)
|
)
|
||||||
engine?.name && setSelectedSettingScreen(engine.name)
|
engine?.name && setSelectedSettingScreen(engine.name)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -55,15 +55,11 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
|
|||||||
.slice(-1)[0]
|
.slice(-1)[0]
|
||||||
|
|
||||||
if (thread) {
|
if (thread) {
|
||||||
// Should also delete error messages to clear out the error state
|
// TODO: Should also delete error messages to clear out the error state
|
||||||
await extensionManager
|
await extensionManager
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
?.writeMessages(
|
?.deleteMessage(thread.id, message.id)
|
||||||
thread.id,
|
.catch(console.error)
|
||||||
messages.filter(
|
|
||||||
(msg) => msg.id !== message.id && msg.status !== MessageStatus.Error
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
const updatedThread: Thread = {
|
const updatedThread: Thread = {
|
||||||
...thread,
|
...thread,
|
||||||
@ -74,7 +70,7 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
|
|||||||
)[
|
)[
|
||||||
messages.filter((msg) => msg.role === ChatCompletionRole.Assistant)
|
messages.filter((msg) => msg.role === ChatCompletionRole.Assistant)
|
||||||
.length - 1
|
.length - 1
|
||||||
]?.content[0]?.text.value,
|
]?.content[0]?.text?.value,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,10 +85,6 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
|
|||||||
setEditMessage(message.id ?? '')
|
setEditMessage(message.id ?? '')
|
||||||
}
|
}
|
||||||
|
|
||||||
const onRegenerateClick = async () => {
|
|
||||||
resendChatMessage(message)
|
|
||||||
}
|
|
||||||
|
|
||||||
if (message.status === MessageStatus.Pending) return null
|
if (message.status === MessageStatus.Pending) return null
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -118,11 +110,10 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
|
|||||||
|
|
||||||
{message.id === messages[messages.length - 1]?.id &&
|
{message.id === messages[messages.length - 1]?.id &&
|
||||||
messages[messages.length - 1].status !== MessageStatus.Error &&
|
messages[messages.length - 1].status !== MessageStatus.Error &&
|
||||||
messages[messages.length - 1].content[0]?.type !==
|
!messages[messages.length - 1].attachments?.length && (
|
||||||
ContentType.Pdf && (
|
|
||||||
<div
|
<div
|
||||||
className="cursor-pointer rounded-lg border border-[hsla(var(--app-border))] p-2"
|
className="cursor-pointer rounded-lg border border-[hsla(var(--app-border))] p-2"
|
||||||
onClick={onRegenerateClick}
|
onClick={resendChatMessage}
|
||||||
>
|
>
|
||||||
<Tooltip
|
<Tooltip
|
||||||
trigger={
|
trigger={
|
||||||
|
|||||||
@ -11,15 +11,7 @@ import { openFileTitle } from '@/utils/titleUtils'
|
|||||||
|
|
||||||
import Icon from '../FileUploadPreview/Icon'
|
import Icon from '../FileUploadPreview/Icon'
|
||||||
|
|
||||||
const DocMessage = ({
|
const DocMessage = ({ id, name }: { id: string; name?: string }) => {
|
||||||
id,
|
|
||||||
name,
|
|
||||||
size,
|
|
||||||
}: {
|
|
||||||
id: string
|
|
||||||
name?: string
|
|
||||||
size?: number
|
|
||||||
}) => {
|
|
||||||
const { onViewFile, onViewFileContainer } = usePath()
|
const { onViewFile, onViewFileContainer } = usePath()
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -44,9 +36,9 @@ const DocMessage = ({
|
|||||||
<h6 className="line-clamp-1 w-4/5 font-medium">
|
<h6 className="line-clamp-1 w-4/5 font-medium">
|
||||||
{name?.replaceAll(/[-._]/g, ' ')}
|
{name?.replaceAll(/[-._]/g, ' ')}
|
||||||
</h6>
|
</h6>
|
||||||
<p className="text-[hsla(var(--text-secondary)]">
|
{/* <p className="text-[hsla(var(--text-secondary)]">
|
||||||
{toGibibytes(Number(size))}
|
{toGibibytes(Number(size))}
|
||||||
</p>
|
</p> */}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import { memo, useMemo } from 'react'
|
import { memo } from 'react'
|
||||||
|
|
||||||
import { ThreadContent } from '@janhq/core'
|
|
||||||
import { Tooltip } from '@janhq/joi'
|
import { Tooltip } from '@janhq/joi'
|
||||||
|
|
||||||
import { FolderOpenIcon } from 'lucide-react'
|
import { FolderOpenIcon } from 'lucide-react'
|
||||||
@ -11,21 +10,13 @@ import { openFileTitle } from '@/utils/titleUtils'
|
|||||||
|
|
||||||
import { RelativeImage } from '../TextMessage/RelativeImage'
|
import { RelativeImage } from '../TextMessage/RelativeImage'
|
||||||
|
|
||||||
const ImageMessage = ({ content }: { content: ThreadContent }) => {
|
const ImageMessage = ({ image }: { image: string }) => {
|
||||||
const { onViewFile, onViewFileContainer } = usePath()
|
const { onViewFile, onViewFileContainer } = usePath()
|
||||||
|
|
||||||
const annotation = useMemo(
|
|
||||||
() => content?.text?.annotations[0] ?? '',
|
|
||||||
[content]
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="group/image relative mb-2 inline-flex cursor-pointer overflow-hidden rounded-xl">
|
<div className="group/image relative mb-2 inline-flex cursor-pointer overflow-hidden rounded-xl">
|
||||||
<div className="left-0 top-0 z-20 h-full w-full group-hover/image:inline-block">
|
<div className="left-0 top-0 z-20 h-full w-full group-hover/image:inline-block">
|
||||||
<RelativeImage
|
<RelativeImage src={image} onClick={() => onViewFile(image)} />
|
||||||
src={annotation}
|
|
||||||
onClick={() => onViewFile(annotation)}
|
|
||||||
/>
|
|
||||||
</div>
|
</div>
|
||||||
<Tooltip
|
<Tooltip
|
||||||
trigger={
|
trigger={
|
||||||
|
|||||||
@ -17,11 +17,11 @@ import DocMessage from './DocMessage'
|
|||||||
import ImageMessage from './ImageMessage'
|
import ImageMessage from './ImageMessage'
|
||||||
import { MarkdownTextMessage } from './MarkdownTextMessage'
|
import { MarkdownTextMessage } from './MarkdownTextMessage'
|
||||||
|
|
||||||
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
import {
|
import {
|
||||||
editMessageAtom,
|
editMessageAtom,
|
||||||
tokenSpeedAtom,
|
tokenSpeedAtom,
|
||||||
} from '@/helpers/atoms/ChatMessage.atom'
|
} from '@/helpers/atoms/ChatMessage.atom'
|
||||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
|
||||||
|
|
||||||
const MessageContainer: React.FC<
|
const MessageContainer: React.FC<
|
||||||
ThreadMessage & { isCurrentMessage: boolean }
|
ThreadMessage & { isCurrentMessage: boolean }
|
||||||
@ -29,18 +29,23 @@ const MessageContainer: React.FC<
|
|||||||
const isUser = props.role === ChatCompletionRole.User
|
const isUser = props.role === ChatCompletionRole.User
|
||||||
const isSystem = props.role === ChatCompletionRole.System
|
const isSystem = props.role === ChatCompletionRole.System
|
||||||
const editMessage = useAtomValue(editMessageAtom)
|
const editMessage = useAtomValue(editMessageAtom)
|
||||||
const activeThread = useAtomValue(activeThreadAtom)
|
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||||
const tokenSpeed = useAtomValue(tokenSpeedAtom)
|
const tokenSpeed = useAtomValue(tokenSpeedAtom)
|
||||||
|
|
||||||
const text = useMemo(
|
const text = useMemo(
|
||||||
() => props.content[0]?.text?.value ?? '',
|
() =>
|
||||||
|
props.content.find((e) => e.type === ContentType.Text)?.text?.value ?? '',
|
||||||
[props.content]
|
[props.content]
|
||||||
)
|
)
|
||||||
const messageType = useMemo(
|
|
||||||
() => props.content[0]?.type ?? '',
|
const image = useMemo(
|
||||||
|
() =>
|
||||||
|
props.content.find((e) => e.type === ContentType.Image)?.image_url?.url,
|
||||||
[props.content]
|
[props.content]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const attachedFile = useMemo(() => 'attachments' in props, [props])
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="group relative mx-auto max-w-[700px] p-4">
|
<div className="group relative mx-auto max-w-[700px] p-4">
|
||||||
<div
|
<div
|
||||||
@ -75,10 +80,10 @@ const MessageContainer: React.FC<
|
|||||||
>
|
>
|
||||||
{isUser
|
{isUser
|
||||||
? props.role
|
? props.role
|
||||||
: (activeThread?.assistants[0].assistant_name ?? props.role)}
|
: (activeAssistant?.assistant_name ?? props.role)}
|
||||||
</div>
|
</div>
|
||||||
<p className="text-xs font-medium text-gray-400">
|
<p className="text-xs font-medium text-gray-400">
|
||||||
{displayDate(props.created)}
|
{props.created && displayDate(props.created ?? new Date())}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@ -111,16 +116,8 @@ const MessageContainer: React.FC<
|
|||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<>
|
<>
|
||||||
{messageType === ContentType.Image && (
|
{image && <ImageMessage image={image} />}
|
||||||
<ImageMessage content={props.content[0]} />
|
{attachedFile && <DocMessage id={props.id} name={props.id} />}
|
||||||
)}
|
|
||||||
{messageType === ContentType.Pdf && (
|
|
||||||
<DocMessage
|
|
||||||
id={props.id}
|
|
||||||
name={props.content[0]?.text?.name}
|
|
||||||
size={props.content[0]?.text?.size}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{editMessage === props.id ? (
|
{editMessage === props.id ? (
|
||||||
<div>
|
<div>
|
||||||
|
|||||||
@ -22,11 +22,14 @@ import { reloadModelAtom } from '@/hooks/useSendChatMessage'
|
|||||||
|
|
||||||
import ChatBody from '@/screens/Thread/ThreadCenterPanel/ChatBody'
|
import ChatBody from '@/screens/Thread/ThreadCenterPanel/ChatBody'
|
||||||
|
|
||||||
|
import { uploader } from '@/utils/file'
|
||||||
|
|
||||||
import ChatInput from './ChatInput'
|
import ChatInput from './ChatInput'
|
||||||
import RequestDownloadModel from './RequestDownloadModel'
|
import RequestDownloadModel from './RequestDownloadModel'
|
||||||
|
|
||||||
import { showSystemMonitorPanelAtom } from '@/helpers/atoms/App.atom'
|
import { showSystemMonitorPanelAtom } from '@/helpers/atoms/App.atom'
|
||||||
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||||
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||||
|
|
||||||
import {
|
import {
|
||||||
@ -55,9 +58,9 @@ const ThreadCenterPanel = () => {
|
|||||||
const setFileUpload = useSetAtom(fileUploadAtom)
|
const setFileUpload = useSetAtom(fileUploadAtom)
|
||||||
const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom)
|
const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom)
|
||||||
const activeThread = useAtomValue(activeThreadAtom)
|
const activeThread = useAtomValue(activeThreadAtom)
|
||||||
|
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||||
const acceptedFormat: Accept = activeThread?.assistants[0].model.settings
|
const upload = uploader()
|
||||||
?.vision_model
|
const acceptedFormat: Accept = activeAssistant?.model.settings?.vision_model
|
||||||
? {
|
? {
|
||||||
'application/pdf': ['.pdf'],
|
'application/pdf': ['.pdf'],
|
||||||
'image/jpeg': ['.jpeg'],
|
'image/jpeg': ['.jpeg'],
|
||||||
@ -78,14 +81,13 @@ const ThreadCenterPanel = () => {
|
|||||||
if (!experimentalFeature) return
|
if (!experimentalFeature) return
|
||||||
if (
|
if (
|
||||||
e.dataTransfer.items.length === 1 &&
|
e.dataTransfer.items.length === 1 &&
|
||||||
((activeThread?.assistants[0].tools &&
|
((activeAssistant?.tools && activeAssistant?.tools[0]?.enabled) ||
|
||||||
activeThread?.assistants[0].tools[0]?.enabled) ||
|
activeAssistant?.model.settings?.vision_model)
|
||||||
activeThread?.assistants[0].model.settings?.vision_model)
|
|
||||||
) {
|
) {
|
||||||
setDragOver(true)
|
setDragOver(true)
|
||||||
} else if (
|
} else if (
|
||||||
activeThread?.assistants[0].tools &&
|
activeAssistant?.tools &&
|
||||||
!activeThread?.assistants[0].tools[0]?.enabled
|
!activeAssistant?.tools[0]?.enabled
|
||||||
) {
|
) {
|
||||||
setDragRejected({ code: 'retrieval-off' })
|
setDragRejected({ code: 'retrieval-off' })
|
||||||
} else {
|
} else {
|
||||||
@ -93,27 +95,36 @@ const ThreadCenterPanel = () => {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
onDragLeave: () => setDragOver(false),
|
onDragLeave: () => setDragOver(false),
|
||||||
onDrop: (files, rejectFiles) => {
|
onDrop: async (files, rejectFiles) => {
|
||||||
// Retrieval file drag and drop is experimental feature
|
// Retrieval file drag and drop is experimental feature
|
||||||
if (!experimentalFeature) return
|
if (!experimentalFeature) return
|
||||||
if (
|
if (
|
||||||
!files ||
|
!files ||
|
||||||
files.length !== 1 ||
|
files.length !== 1 ||
|
||||||
rejectFiles.length !== 0 ||
|
rejectFiles.length !== 0 ||
|
||||||
(activeThread?.assistants[0].tools &&
|
(activeAssistant?.tools &&
|
||||||
!activeThread?.assistants[0].tools[0]?.enabled &&
|
!activeAssistant?.tools[0]?.enabled &&
|
||||||
!activeThread?.assistants[0].model.settings?.vision_model)
|
!activeAssistant?.model.settings?.vision_model)
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
const imageType = files[0]?.type.includes('image')
|
const imageType = files[0]?.type.includes('image')
|
||||||
setFileUpload([{ file: files[0], type: imageType ? 'image' : 'pdf' }])
|
if (imageType) {
|
||||||
|
setFileUpload({ file: files[0], type: 'image' })
|
||||||
|
} else {
|
||||||
|
upload.addFile(files[0])
|
||||||
|
upload.upload().then((data) => {
|
||||||
|
setFileUpload({
|
||||||
|
file: files[0],
|
||||||
|
type: imageType ? 'image' : 'pdf',
|
||||||
|
id: data?.successful?.[0]?.response?.body?.id,
|
||||||
|
name: data?.successful?.[0]?.response?.body?.filename,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
setDragOver(false)
|
setDragOver(false)
|
||||||
},
|
},
|
||||||
onDropRejected: (e) => {
|
onDropRejected: (e) => {
|
||||||
if (
|
if (activeAssistant?.tools && !activeAssistant?.tools[0]?.enabled) {
|
||||||
activeThread?.assistants[0].tools &&
|
|
||||||
!activeThread?.assistants[0].tools[0]?.enabled
|
|
||||||
) {
|
|
||||||
setDragRejected({ code: 'retrieval-off' })
|
setDragRejected({ code: 'retrieval-off' })
|
||||||
} else {
|
} else {
|
||||||
setDragRejected({ code: e[0].errors[0].code })
|
setDragRejected({ code: e[0].errors[0].code })
|
||||||
@ -186,8 +197,7 @@ const ThreadCenterPanel = () => {
|
|||||||
<h6 className="font-bold">
|
<h6 className="font-bold">
|
||||||
{isDragReject
|
{isDragReject
|
||||||
? `Currently, we only support 1 attachment at the same time with ${
|
? `Currently, we only support 1 attachment at the same time with ${
|
||||||
activeThread?.assistants[0].model.settings
|
activeAssistant?.model.settings?.vision_model
|
||||||
?.vision_model
|
|
||||||
? 'PDF, JPEG, JPG, PNG'
|
? 'PDF, JPEG, JPG, PNG'
|
||||||
: 'PDF'
|
: 'PDF'
|
||||||
} format`
|
} format`
|
||||||
@ -195,7 +205,7 @@ const ThreadCenterPanel = () => {
|
|||||||
</h6>
|
</h6>
|
||||||
{!isDragReject && (
|
{!isDragReject && (
|
||||||
<p className="mt-2">
|
<p className="mt-2">
|
||||||
{activeThread?.assistants[0].model.settings?.vision_model
|
{activeAssistant?.model.settings?.vision_model
|
||||||
? 'PDF, JPEG, JPG, PNG'
|
? 'PDF, JPEG, JPG, PNG'
|
||||||
: 'PDF'}
|
: 'PDF'}
|
||||||
</p>
|
</p>
|
||||||
|
|||||||
@ -15,13 +15,15 @@ const ModalEditTitleThread = () => {
|
|||||||
const [modalActionThread, setModalActionThread] = useAtom(
|
const [modalActionThread, setModalActionThread] = useAtom(
|
||||||
modalActionThreadAtom
|
modalActionThreadAtom
|
||||||
)
|
)
|
||||||
const [title, setTitle] = useState(modalActionThread.thread?.title as string)
|
const [title, setTitle] = useState(
|
||||||
|
modalActionThread.thread?.metadata?.title as string
|
||||||
|
)
|
||||||
|
|
||||||
useLayoutEffect(() => {
|
useLayoutEffect(() => {
|
||||||
if (modalActionThread.thread?.title) {
|
if (modalActionThread.thread?.metadata?.title) {
|
||||||
setTitle(modalActionThread.thread?.title)
|
setTitle(modalActionThread.thread?.metadata?.title as string)
|
||||||
}
|
}
|
||||||
}, [modalActionThread.thread?.title])
|
}, [modalActionThread.thread?.metadata])
|
||||||
|
|
||||||
const onUpdateTitle = useCallback(
|
const onUpdateTitle = useCallback(
|
||||||
(e: React.MouseEvent<HTMLButtonElement, MouseEvent>) => {
|
(e: React.MouseEvent<HTMLButtonElement, MouseEvent>) => {
|
||||||
@ -30,6 +32,10 @@ const ModalEditTitleThread = () => {
|
|||||||
updateThreadMetadata({
|
updateThreadMetadata({
|
||||||
...modalActionThread?.thread,
|
...modalActionThread?.thread,
|
||||||
title: title || 'New Thread',
|
title: title || 'New Thread',
|
||||||
|
metadata: {
|
||||||
|
...modalActionThread?.thread.metadata,
|
||||||
|
title: title || 'New Thread',
|
||||||
|
},
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
[modalActionThread?.thread, title, updateThreadMetadata]
|
[modalActionThread?.thread, title, updateThreadMetadata]
|
||||||
|
|||||||
@ -20,7 +20,10 @@ import { useCreateNewThread } from '@/hooks/useCreateNewThread'
|
|||||||
import useRecommendedModel from '@/hooks/useRecommendedModel'
|
import useRecommendedModel from '@/hooks/useRecommendedModel'
|
||||||
import useSetActiveThread from '@/hooks/useSetActiveThread'
|
import useSetActiveThread from '@/hooks/useSetActiveThread'
|
||||||
|
|
||||||
import { assistantsAtom } from '@/helpers/atoms/Assistant.atom'
|
import {
|
||||||
|
activeAssistantAtom,
|
||||||
|
assistantsAtom,
|
||||||
|
} from '@/helpers/atoms/Assistant.atom'
|
||||||
import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom'
|
import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom'
|
||||||
|
|
||||||
import {
|
import {
|
||||||
@ -34,6 +37,7 @@ import {
|
|||||||
const ThreadLeftPanel = () => {
|
const ThreadLeftPanel = () => {
|
||||||
const threads = useAtomValue(threadsAtom)
|
const threads = useAtomValue(threadsAtom)
|
||||||
const activeThreadId = useAtomValue(getActiveThreadIdAtom)
|
const activeThreadId = useAtomValue(getActiveThreadIdAtom)
|
||||||
|
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||||
const { setActiveThread } = useSetActiveThread()
|
const { setActiveThread } = useSetActiveThread()
|
||||||
const assistants = useAtomValue(assistantsAtom)
|
const assistants = useAtomValue(assistantsAtom)
|
||||||
const threadDataReady = useAtomValue(threadDataReadyAtom)
|
const threadDataReady = useAtomValue(threadDataReadyAtom)
|
||||||
@ -67,6 +71,7 @@ const ThreadLeftPanel = () => {
|
|||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (
|
if (
|
||||||
threadDataReady &&
|
threadDataReady &&
|
||||||
|
activeAssistant &&
|
||||||
assistants.length > 0 &&
|
assistants.length > 0 &&
|
||||||
threads.length === 0 &&
|
threads.length === 0 &&
|
||||||
downloadedModels.length > 0
|
downloadedModels.length > 0
|
||||||
@ -75,7 +80,10 @@ const ThreadLeftPanel = () => {
|
|||||||
(model) => model.engine === InferenceEngine.cortex_llamacpp
|
(model) => model.engine === InferenceEngine.cortex_llamacpp
|
||||||
)
|
)
|
||||||
const selectedModel = model[0] || recommendedModel
|
const selectedModel = model[0] || recommendedModel
|
||||||
requestCreateNewThread(assistants[0], selectedModel)
|
requestCreateNewThread(
|
||||||
|
{ ...assistants[0], ...activeAssistant },
|
||||||
|
selectedModel
|
||||||
|
)
|
||||||
} else if (threadDataReady && !activeThreadId) {
|
} else if (threadDataReady && !activeThreadId) {
|
||||||
setActiveThread(threads[0])
|
setActiveThread(threads[0])
|
||||||
}
|
}
|
||||||
@ -88,6 +96,7 @@ const ThreadLeftPanel = () => {
|
|||||||
setActiveThread,
|
setActiveThread,
|
||||||
recommendedModel,
|
recommendedModel,
|
||||||
downloadedModels,
|
downloadedModels,
|
||||||
|
activeAssistant,
|
||||||
])
|
])
|
||||||
|
|
||||||
const onContextMenu = (event: React.MouseEvent, thread: Thread) => {
|
const onContextMenu = (event: React.MouseEvent, thread: Thread) => {
|
||||||
@ -138,7 +147,7 @@ const ThreadLeftPanel = () => {
|
|||||||
activeThreadId && 'font-medium'
|
activeThreadId && 'font-medium'
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
{thread.title}
|
{thread.title ?? thread.metadata?.title}
|
||||||
</h1>
|
</h1>
|
||||||
</div>
|
</div>
|
||||||
<div
|
<div
|
||||||
|
|||||||
@ -14,48 +14,54 @@ import AssistantSetting from '@/screens/Thread/ThreadCenterPanel/AssistantSettin
|
|||||||
import { getConfigurationsData } from '@/utils/componentSettings'
|
import { getConfigurationsData } from '@/utils/componentSettings'
|
||||||
|
|
||||||
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||||
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||||
|
|
||||||
const Tools = () => {
|
const Tools = () => {
|
||||||
const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom)
|
const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom)
|
||||||
const activeThread = useAtomValue(activeThreadAtom)
|
const activeThread = useAtomValue(activeThreadAtom)
|
||||||
|
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||||
const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
|
const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
|
||||||
const { updateThreadMetadata } = useCreateNewThread()
|
const { updateThreadMetadata } = useCreateNewThread()
|
||||||
const { recommendedModel, downloadedModels } = useRecommendedModel()
|
const { recommendedModel, downloadedModels } = useRecommendedModel()
|
||||||
|
|
||||||
const componentDataAssistantSetting = getConfigurationsData(
|
const componentDataAssistantSetting = getConfigurationsData(
|
||||||
(activeThread?.assistants[0]?.tools &&
|
(activeAssistant?.tools && activeAssistant?.tools[0]?.settings) ?? {}
|
||||||
activeThread?.assistants[0]?.tools[0]?.settings) ??
|
|
||||||
{}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!activeThread) return
|
if (!activeThread) return
|
||||||
let model = downloadedModels.find(
|
let model = downloadedModels.find(
|
||||||
(model) => model.id === activeThread.assistants[0].model.id
|
(model) => model.id === activeAssistant?.model.id
|
||||||
)
|
)
|
||||||
if (!model) {
|
if (!model) {
|
||||||
model = recommendedModel
|
model = recommendedModel
|
||||||
}
|
}
|
||||||
setSelectedModel(model)
|
setSelectedModel(model)
|
||||||
}, [recommendedModel, activeThread, downloadedModels, setSelectedModel])
|
}, [
|
||||||
|
recommendedModel,
|
||||||
|
activeThread,
|
||||||
|
downloadedModels,
|
||||||
|
setSelectedModel,
|
||||||
|
activeAssistant?.model.id,
|
||||||
|
])
|
||||||
|
|
||||||
const onRetrievalSwitchUpdate = useCallback(
|
const onRetrievalSwitchUpdate = useCallback(
|
||||||
(enabled: boolean) => {
|
(enabled: boolean) => {
|
||||||
if (!activeThread) return
|
if (!activeThread || !activeAssistant) return
|
||||||
updateThreadMetadata({
|
updateThreadMetadata({
|
||||||
...activeThread,
|
...activeThread,
|
||||||
assistants: [
|
assistants: [
|
||||||
{
|
{
|
||||||
...activeThread.assistants[0],
|
...activeAssistant,
|
||||||
tools: [
|
tools: [
|
||||||
{
|
{
|
||||||
type: 'retrieval',
|
type: 'retrieval',
|
||||||
enabled: enabled,
|
enabled: enabled,
|
||||||
settings:
|
settings:
|
||||||
(activeThread.assistants[0].tools &&
|
(activeAssistant.tools &&
|
||||||
activeThread.assistants[0].tools[0]?.settings) ??
|
activeAssistant.tools[0]?.settings) ??
|
||||||
{},
|
{},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@ -63,25 +69,25 @@ const Tools = () => {
|
|||||||
],
|
],
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
[activeThread, updateThreadMetadata]
|
[activeAssistant, activeThread, updateThreadMetadata]
|
||||||
)
|
)
|
||||||
|
|
||||||
const onTimeWeightedRetrieverSwitchUpdate = useCallback(
|
const onTimeWeightedRetrieverSwitchUpdate = useCallback(
|
||||||
(enabled: boolean) => {
|
(enabled: boolean) => {
|
||||||
if (!activeThread) return
|
if (!activeThread || !activeAssistant) return
|
||||||
updateThreadMetadata({
|
updateThreadMetadata({
|
||||||
...activeThread,
|
...activeThread,
|
||||||
assistants: [
|
assistants: [
|
||||||
{
|
{
|
||||||
...activeThread.assistants[0],
|
...activeAssistant,
|
||||||
tools: [
|
tools: [
|
||||||
{
|
{
|
||||||
type: 'retrieval',
|
type: 'retrieval',
|
||||||
enabled: true,
|
enabled: true,
|
||||||
useTimeWeightedRetriever: enabled,
|
useTimeWeightedRetriever: enabled,
|
||||||
settings:
|
settings:
|
||||||
(activeThread.assistants[0].tools &&
|
(activeAssistant.tools &&
|
||||||
activeThread.assistants[0].tools[0]?.settings) ??
|
activeAssistant.tools[0]?.settings) ??
|
||||||
{},
|
{},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@ -89,15 +95,14 @@ const Tools = () => {
|
|||||||
],
|
],
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
[activeThread, updateThreadMetadata]
|
[activeAssistant, activeThread, updateThreadMetadata]
|
||||||
)
|
)
|
||||||
|
|
||||||
if (!experimentalFeature) return null
|
if (!experimentalFeature) return null
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Fragment>
|
<Fragment>
|
||||||
{activeThread?.assistants[0]?.tools &&
|
{activeAssistant?.tools && componentDataAssistantSetting.length > 0 && (
|
||||||
componentDataAssistantSetting.length > 0 && (
|
|
||||||
<div className="p-4">
|
<div className="p-4">
|
||||||
<div className="mb-2">
|
<div className="mb-2">
|
||||||
<div className="flex items-center justify-between">
|
<div className="flex items-center justify-between">
|
||||||
@ -122,13 +127,13 @@ const Tools = () => {
|
|||||||
<div className="flex items-center justify-between">
|
<div className="flex items-center justify-between">
|
||||||
<Switch
|
<Switch
|
||||||
name="retrieval"
|
name="retrieval"
|
||||||
checked={activeThread?.assistants[0].tools[0].enabled}
|
checked={activeAssistant?.tools[0].enabled}
|
||||||
onChange={(e) => onRetrievalSwitchUpdate(e.target.checked)}
|
onChange={(e) => onRetrievalSwitchUpdate(e.target.checked)}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
{activeThread?.assistants[0]?.tools[0].enabled && (
|
{activeAssistant?.tools[0].enabled && (
|
||||||
<div className="pb-4 pt-2">
|
<div className="pb-4 pt-2">
|
||||||
<div className="mb-4">
|
<div className="mb-4">
|
||||||
<div className="item-center mb-2 flex">
|
<div className="item-center mb-2 flex">
|
||||||
@ -155,11 +160,7 @@ const Tools = () => {
|
|||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div className="w-full">
|
<div className="w-full">
|
||||||
<Input
|
<Input value={selectedModel?.name || ''} disabled readOnly />
|
||||||
value={selectedModel?.name || ''}
|
|
||||||
disabled
|
|
||||||
readOnly
|
|
||||||
/>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="mb-4">
|
<div className="mb-4">
|
||||||
@ -214,8 +215,8 @@ const Tools = () => {
|
|||||||
<Switch
|
<Switch
|
||||||
name="use-time-weighted-retriever"
|
name="use-time-weighted-retriever"
|
||||||
checked={
|
checked={
|
||||||
activeThread?.assistants[0].tools[0]
|
activeAssistant?.tools[0].useTimeWeightedRetriever ||
|
||||||
.useTimeWeightedRetriever || false
|
false
|
||||||
}
|
}
|
||||||
onChange={(e) =>
|
onChange={(e) =>
|
||||||
onTimeWeightedRetrieverSwitchUpdate(e.target.checked)
|
onTimeWeightedRetrieverSwitchUpdate(e.target.checked)
|
||||||
@ -224,9 +225,7 @@ const Tools = () => {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<AssistantSetting
|
<AssistantSetting componentData={componentDataAssistantSetting} />
|
||||||
componentData={componentDataAssistantSetting}
|
|
||||||
/>
|
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@ -38,6 +38,7 @@ import PromptTemplateSetting from './PromptTemplateSetting'
|
|||||||
import Tools from './Tools'
|
import Tools from './Tools'
|
||||||
|
|
||||||
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||||
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||||
import {
|
import {
|
||||||
activeThreadAtom,
|
activeThreadAtom,
|
||||||
@ -53,6 +54,7 @@ const ENGINE_SETTINGS = 'Engine Settings'
|
|||||||
|
|
||||||
const ThreadRightPanel = () => {
|
const ThreadRightPanel = () => {
|
||||||
const activeThread = useAtomValue(activeThreadAtom)
|
const activeThread = useAtomValue(activeThreadAtom)
|
||||||
|
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||||
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
|
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
|
||||||
const selectedModel = useAtomValue(selectedModelAtom)
|
const selectedModel = useAtomValue(selectedModelAtom)
|
||||||
const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom(
|
const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom(
|
||||||
@ -154,18 +156,18 @@ const ThreadRightPanel = () => {
|
|||||||
|
|
||||||
const onAssistantInstructionChanged = useCallback(
|
const onAssistantInstructionChanged = useCallback(
|
||||||
(e: React.ChangeEvent<HTMLTextAreaElement>) => {
|
(e: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||||
if (activeThread)
|
if (activeThread && activeAssistant)
|
||||||
updateThreadMetadata({
|
updateThreadMetadata({
|
||||||
...activeThread,
|
...activeThread,
|
||||||
assistants: [
|
assistants: [
|
||||||
{
|
{
|
||||||
...activeThread.assistants[0],
|
...activeAssistant,
|
||||||
instructions: e.target.value || '',
|
instructions: e.target.value || '',
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
[activeThread, updateThreadMetadata]
|
[activeAssistant, activeThread, updateThreadMetadata]
|
||||||
)
|
)
|
||||||
|
|
||||||
const resetModel = useDebouncedCallback(() => {
|
const resetModel = useDebouncedCallback(() => {
|
||||||
@ -174,9 +176,7 @@ const ThreadRightPanel = () => {
|
|||||||
|
|
||||||
const onValueChanged = useCallback(
|
const onValueChanged = useCallback(
|
||||||
(key: string, value: string | number | boolean | string[]) => {
|
(key: string, value: string | number | boolean | string[]) => {
|
||||||
if (!activeThread) {
|
if (!activeThread || !activeAssistant) return
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
setEngineParamsUpdate(true)
|
setEngineParamsUpdate(true)
|
||||||
resetModel()
|
resetModel()
|
||||||
@ -186,32 +186,38 @@ const ThreadRightPanel = () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if (
|
if (
|
||||||
activeThread.assistants[0].model.parameters?.max_tokens &&
|
activeAssistant.model.parameters?.max_tokens &&
|
||||||
activeThread.assistants[0].model.settings?.ctx_len
|
activeAssistant.model.settings?.ctx_len
|
||||||
) {
|
) {
|
||||||
if (
|
if (
|
||||||
key === 'max_tokens' &&
|
key === 'max_tokens' &&
|
||||||
Number(value) > activeThread.assistants[0].model.settings.ctx_len
|
Number(value) > activeAssistant.model.settings.ctx_len
|
||||||
) {
|
) {
|
||||||
updateModelParameter(activeThread, {
|
updateModelParameter(activeThread, {
|
||||||
params: {
|
params: {
|
||||||
max_tokens: activeThread.assistants[0].model.settings.ctx_len,
|
max_tokens: activeAssistant.model.settings.ctx_len,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if (
|
if (
|
||||||
key === 'ctx_len' &&
|
key === 'ctx_len' &&
|
||||||
Number(value) < activeThread.assistants[0].model.parameters.max_tokens
|
Number(value) < activeAssistant.model.parameters.max_tokens
|
||||||
) {
|
) {
|
||||||
updateModelParameter(activeThread, {
|
updateModelParameter(activeThread, {
|
||||||
params: {
|
params: {
|
||||||
max_tokens: activeThread.assistants[0].model.settings.ctx_len,
|
max_tokens: activeAssistant.model.settings.ctx_len,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[activeThread, resetModel, setEngineParamsUpdate, updateModelParameter]
|
[
|
||||||
|
activeAssistant,
|
||||||
|
activeThread,
|
||||||
|
resetModel,
|
||||||
|
setEngineParamsUpdate,
|
||||||
|
updateModelParameter,
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
if (!activeThread) {
|
if (!activeThread) {
|
||||||
@ -250,7 +256,7 @@ const ThreadRightPanel = () => {
|
|||||||
<TextArea
|
<TextArea
|
||||||
id="assistant-instructions"
|
id="assistant-instructions"
|
||||||
placeholder="Eg. You are a helpful assistant."
|
placeholder="Eg. You are a helpful assistant."
|
||||||
value={activeThread?.assistants[0].instructions ?? ''}
|
value={activeAssistant?.instructions ?? ''}
|
||||||
autoResize
|
autoResize
|
||||||
onChange={onAssistantInstructionChanged}
|
onChange={onAssistantInstructionChanged}
|
||||||
/>
|
/>
|
||||||
|
|||||||
@ -12,6 +12,9 @@ global.ResizeObserver = class {
|
|||||||
// Mock the useStarterScreen hook
|
// Mock the useStarterScreen hook
|
||||||
jest.mock('@/hooks/useStarterScreen')
|
jest.mock('@/hooks/useStarterScreen')
|
||||||
|
|
||||||
|
// @ts-ignore
|
||||||
|
global.API_BASE_URL = 'http://localhost:3000'
|
||||||
|
|
||||||
describe('ThreadScreen', () => {
|
describe('ThreadScreen', () => {
|
||||||
it('renders OnDeviceStarterScreen when isShowStarterScreen is true', () => {
|
it('renders OnDeviceStarterScreen when isShowStarterScreen is true', () => {
|
||||||
; (useStarterScreen as jest.Mock).mockReturnValue({
|
; (useStarterScreen as jest.Mock).mockReturnValue({
|
||||||
|
|||||||
2
web/types/file.d.ts
vendored
2
web/types/file.d.ts
vendored
@ -3,4 +3,6 @@ export type FileType = 'image' | 'pdf'
|
|||||||
export type FileInfo = {
|
export type FileInfo = {
|
||||||
file: File
|
file: File
|
||||||
type: FileType
|
type: FileType
|
||||||
|
id?: string
|
||||||
|
name?: string
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,4 +1,6 @@
|
|||||||
import { baseName } from '@janhq/core'
|
import { baseName } from '@janhq/core'
|
||||||
|
import Uppy from '@uppy/core'
|
||||||
|
import XHR from '@uppy/xhr-upload'
|
||||||
|
|
||||||
export type FilePathWithSize = {
|
export type FilePathWithSize = {
|
||||||
path: string
|
path: string
|
||||||
@ -27,3 +29,21 @@ export const getFileInfoFromFile = async (
|
|||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This function creates an Uppy instance with XHR plugin for file upload to the server.
|
||||||
|
* @returns Uppy instance
|
||||||
|
*/
|
||||||
|
export const uploader = () => {
|
||||||
|
const uppy = new Uppy().use(XHR, {
|
||||||
|
endpoint: `${API_BASE_URL}/v1/files`,
|
||||||
|
method: 'POST',
|
||||||
|
fieldName: 'file',
|
||||||
|
formData: true,
|
||||||
|
limit: 1,
|
||||||
|
})
|
||||||
|
uppy.setMeta({
|
||||||
|
purpose: 'assistants',
|
||||||
|
})
|
||||||
|
return uppy
|
||||||
|
}
|
||||||
|
|||||||
@ -15,7 +15,7 @@ import { ulid } from 'ulidx'
|
|||||||
|
|
||||||
import { Stack } from '@/utils/Stack'
|
import { Stack } from '@/utils/Stack'
|
||||||
|
|
||||||
import { FileType } from '@/types/file'
|
import { FileInfo, FileType } from '@/types/file'
|
||||||
|
|
||||||
export class MessageRequestBuilder {
|
export class MessageRequestBuilder {
|
||||||
msgId: string
|
msgId: string
|
||||||
@ -38,7 +38,7 @@ export class MessageRequestBuilder {
|
|||||||
.filter((e) => e.status !== MessageStatus.Error)
|
.filter((e) => e.status !== MessageStatus.Error)
|
||||||
.map<ChatCompletionMessage>((msg) => ({
|
.map<ChatCompletionMessage>((msg) => ({
|
||||||
role: msg.role,
|
role: msg.role,
|
||||||
content: msg.content[0]?.text.value ?? '.',
|
content: msg.content[0]?.text?.value ?? '.',
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,11 +46,11 @@ export class MessageRequestBuilder {
|
|||||||
pushMessage(
|
pushMessage(
|
||||||
message: string,
|
message: string,
|
||||||
base64Blob: string | undefined,
|
base64Blob: string | undefined,
|
||||||
fileContentType: FileType
|
fileInfo?: FileInfo
|
||||||
) {
|
) {
|
||||||
if (base64Blob && fileContentType === 'pdf')
|
if (base64Blob && fileInfo?.type === 'pdf')
|
||||||
return this.addDocMessage(message)
|
return this.addDocMessage(message, fileInfo?.name)
|
||||||
else if (base64Blob && fileContentType === 'image') {
|
else if (base64Blob && fileInfo?.type === 'image') {
|
||||||
return this.addImageMessage(message, base64Blob)
|
return this.addImageMessage(message, base64Blob)
|
||||||
}
|
}
|
||||||
this.messages = [
|
this.messages = [
|
||||||
@ -77,7 +77,7 @@ export class MessageRequestBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Chainable
|
// Chainable
|
||||||
addDocMessage(prompt: string) {
|
addDocMessage(prompt: string, name?: string) {
|
||||||
const message: ChatCompletionMessage = {
|
const message: ChatCompletionMessage = {
|
||||||
role: ChatCompletionRole.User,
|
role: ChatCompletionRole.User,
|
||||||
content: [
|
content: [
|
||||||
@ -88,7 +88,7 @@ export class MessageRequestBuilder {
|
|||||||
{
|
{
|
||||||
type: ChatCompletionMessageContentType.Doc,
|
type: ChatCompletionMessageContentType.Doc,
|
||||||
doc_url: {
|
doc_url: {
|
||||||
url: `threads/${this.thread.id}/files/${this.msgId}.pdf`,
|
url: name ?? `${this.msgId}.pdf`,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
] as ChatCompletionMessageContent,
|
] as ChatCompletionMessageContent,
|
||||||
@ -163,6 +163,7 @@ export class MessageRequestBuilder {
|
|||||||
return {
|
return {
|
||||||
id: this.msgId,
|
id: this.msgId,
|
||||||
type: this.type,
|
type: this.type,
|
||||||
|
attachments: [],
|
||||||
threadId: this.thread.id,
|
threadId: this.thread.id,
|
||||||
messages: this.normalizeMessages(this.messages),
|
messages: this.normalizeMessages(this.messages),
|
||||||
model: this.model,
|
model: this.model,
|
||||||
|
|||||||
@ -1,16 +1,19 @@
|
|||||||
|
import {
|
||||||
import { ChatCompletionRole, MessageStatus } from '@janhq/core'
|
ChatCompletionRole,
|
||||||
|
MessageRequestType,
|
||||||
|
MessageStatus,
|
||||||
|
} from '@janhq/core'
|
||||||
|
|
||||||
import { ThreadMessageBuilder } from './threadMessageBuilder'
|
import { ThreadMessageBuilder } from './threadMessageBuilder'
|
||||||
import { MessageRequestBuilder } from './messageRequestBuilder'
|
import { MessageRequestBuilder } from './messageRequestBuilder'
|
||||||
|
|
||||||
import { ContentType } from '@janhq/core';
|
import { ContentType } from '@janhq/core'
|
||||||
describe('ThreadMessageBuilder', () => {
|
describe('ThreadMessageBuilder', () => {
|
||||||
it('testBuildMethod', () => {
|
it('testBuildMethod', () => {
|
||||||
const msgRequest = new MessageRequestBuilder(
|
const msgRequest = new MessageRequestBuilder(
|
||||||
'type',
|
MessageRequestType.Thread,
|
||||||
{ model: 'model' },
|
{ model: 'model' } as any,
|
||||||
{ id: 'thread-id' },
|
{ id: 'thread-id' } as any,
|
||||||
[]
|
[]
|
||||||
)
|
)
|
||||||
const builder = new ThreadMessageBuilder(msgRequest)
|
const builder = new ThreadMessageBuilder(msgRequest)
|
||||||
@ -29,14 +32,14 @@ import { ContentType } from '@janhq/core';
|
|||||||
|
|
||||||
it('testPushMessageWithPromptOnly', () => {
|
it('testPushMessageWithPromptOnly', () => {
|
||||||
const msgRequest = new MessageRequestBuilder(
|
const msgRequest = new MessageRequestBuilder(
|
||||||
'type',
|
MessageRequestType.Thread,
|
||||||
{ model: 'model' },
|
{ model: 'model' } as any,
|
||||||
{ id: 'thread-id' },
|
{ id: 'thread-id' } as any,
|
||||||
[]
|
[]
|
||||||
);
|
)
|
||||||
const builder = new ThreadMessageBuilder(msgRequest);
|
const builder = new ThreadMessageBuilder(msgRequest)
|
||||||
const prompt = 'test prompt';
|
const prompt = 'test prompt'
|
||||||
builder.pushMessage(prompt, undefined, []);
|
builder.pushMessage(prompt, undefined, undefined)
|
||||||
expect(builder.content).toEqual([
|
expect(builder.content).toEqual([
|
||||||
{
|
{
|
||||||
type: ContentType.Text,
|
type: ContentType.Text,
|
||||||
@ -45,56 +48,53 @@ import { ContentType } from '@janhq/core';
|
|||||||
annotations: [],
|
annotations: [],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
]);
|
])
|
||||||
});
|
})
|
||||||
|
|
||||||
|
|
||||||
it('testPushMessageWithPdf', () => {
|
it('testPushMessageWithPdf', () => {
|
||||||
const msgRequest = new MessageRequestBuilder(
|
const msgRequest = new MessageRequestBuilder(
|
||||||
'type',
|
MessageRequestType.Thread,
|
||||||
{ model: 'model' },
|
{ model: 'model' } as any,
|
||||||
{ id: 'thread-id' },
|
{ id: 'thread-id' } as any,
|
||||||
[]
|
[]
|
||||||
);
|
)
|
||||||
const builder = new ThreadMessageBuilder(msgRequest);
|
const builder = new ThreadMessageBuilder(msgRequest)
|
||||||
const prompt = 'test prompt';
|
const prompt = 'test prompt'
|
||||||
const base64 = 'test base64';
|
const base64 = 'test base64'
|
||||||
const fileUpload = [{ type: 'pdf', file: { name: 'test.pdf', size: 1000 } }];
|
const fileUpload = [
|
||||||
builder.pushMessage(prompt, base64, fileUpload);
|
{ type: 'pdf', file: { name: 'test.pdf', size: 1000 } },
|
||||||
|
] as any
|
||||||
|
builder.pushMessage(prompt, base64, fileUpload)
|
||||||
expect(builder.content).toEqual([
|
expect(builder.content).toEqual([
|
||||||
{
|
{
|
||||||
type: ContentType.Pdf,
|
type: ContentType.Text,
|
||||||
text: {
|
text: {
|
||||||
value: prompt,
|
value: prompt,
|
||||||
annotations: [base64],
|
annotations: [],
|
||||||
name: fileUpload[0].file.name,
|
|
||||||
size: fileUpload[0].file.size,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
]);
|
])
|
||||||
});
|
})
|
||||||
|
|
||||||
|
|
||||||
it('testPushMessageWithImage', () => {
|
it('testPushMessageWithImage', () => {
|
||||||
const msgRequest = new MessageRequestBuilder(
|
const msgRequest = new MessageRequestBuilder(
|
||||||
'type',
|
MessageRequestType.Thread,
|
||||||
{ model: 'model' },
|
{ model: 'model' } as any,
|
||||||
{ id: 'thread-id' },
|
{ id: 'thread-id' } as any,
|
||||||
[]
|
[]
|
||||||
);
|
)
|
||||||
const builder = new ThreadMessageBuilder(msgRequest);
|
const builder = new ThreadMessageBuilder(msgRequest)
|
||||||
const prompt = 'test prompt';
|
const prompt = 'test prompt'
|
||||||
const base64 = 'test base64';
|
const base64 = 'test base64'
|
||||||
const fileUpload = [{ type: 'image', file: { name: 'test.jpg', size: 1000 } }];
|
const fileUpload = [{ type: 'image', file: { name: 'test.jpg', size: 1000 } }]
|
||||||
builder.pushMessage(prompt, base64, fileUpload);
|
builder.pushMessage(prompt, base64, fileUpload as any)
|
||||||
expect(builder.content).toEqual([
|
expect(builder.content).toEqual([
|
||||||
{
|
{
|
||||||
type: ContentType.Image,
|
type: ContentType.Text,
|
||||||
text: {
|
text: {
|
||||||
value: prompt,
|
value: prompt,
|
||||||
annotations: [base64],
|
annotations: [],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
]);
|
])
|
||||||
});
|
})
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import {
|
import {
|
||||||
|
Attachment,
|
||||||
ChatCompletionRole,
|
ChatCompletionRole,
|
||||||
ContentType,
|
ContentType,
|
||||||
MessageStatus,
|
MessageStatus,
|
||||||
@ -14,6 +15,7 @@ export class ThreadMessageBuilder {
|
|||||||
messageRequest: MessageRequestBuilder
|
messageRequest: MessageRequestBuilder
|
||||||
|
|
||||||
content: ThreadContent[] = []
|
content: ThreadContent[] = []
|
||||||
|
attachments: Attachment[] = []
|
||||||
|
|
||||||
constructor(messageRequest: MessageRequestBuilder) {
|
constructor(messageRequest: MessageRequestBuilder) {
|
||||||
this.messageRequest = messageRequest
|
this.messageRequest = messageRequest
|
||||||
@ -24,6 +26,7 @@ export class ThreadMessageBuilder {
|
|||||||
return {
|
return {
|
||||||
id: this.messageRequest.msgId,
|
id: this.messageRequest.msgId,
|
||||||
thread_id: this.messageRequest.thread.id,
|
thread_id: this.messageRequest.thread.id,
|
||||||
|
attachments: this.attachments,
|
||||||
role: ChatCompletionRole.User,
|
role: ChatCompletionRole.User,
|
||||||
status: MessageStatus.Ready,
|
status: MessageStatus.Ready,
|
||||||
created: timestamp,
|
created: timestamp,
|
||||||
@ -36,31 +39,9 @@ export class ThreadMessageBuilder {
|
|||||||
pushMessage(
|
pushMessage(
|
||||||
prompt: string,
|
prompt: string,
|
||||||
base64: string | undefined,
|
base64: string | undefined,
|
||||||
fileUpload: FileInfo[]
|
fileUpload?: FileInfo
|
||||||
) {
|
) {
|
||||||
if (base64 && fileUpload[0]?.type === 'image') {
|
if (prompt) {
|
||||||
this.content.push({
|
|
||||||
type: ContentType.Image,
|
|
||||||
text: {
|
|
||||||
value: prompt,
|
|
||||||
annotations: [base64],
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if (base64 && fileUpload[0]?.type === 'pdf') {
|
|
||||||
this.content.push({
|
|
||||||
type: ContentType.Pdf,
|
|
||||||
text: {
|
|
||||||
value: prompt,
|
|
||||||
annotations: [base64],
|
|
||||||
name: fileUpload[0].file.name,
|
|
||||||
size: fileUpload[0].file.size,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if (prompt && !base64) {
|
|
||||||
this.content.push({
|
this.content.push({
|
||||||
type: ContentType.Text,
|
type: ContentType.Text,
|
||||||
text: {
|
text: {
|
||||||
@ -69,6 +50,26 @@ export class ThreadMessageBuilder {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
if (base64 && fileUpload?.type === 'image') {
|
||||||
|
this.content.push({
|
||||||
|
type: ContentType.Image,
|
||||||
|
image_url: {
|
||||||
|
url: base64,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if (base64 && fileUpload?.type === 'pdf') {
|
||||||
|
this.attachments.push({
|
||||||
|
file_id: fileUpload.id,
|
||||||
|
tools: [
|
||||||
|
{
|
||||||
|
type: 'file_search',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return this
|
return this
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user