fix: Add hack waiting for model loading

This commit is contained in:
hiro 2023-12-05 22:48:04 +07:00
parent 975e9718bf
commit 9daee14167
2 changed files with 22 additions and 7 deletions

View File

@ -198,7 +198,6 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
instance: JanInferenceNitroExtension instance: JanInferenceNitroExtension
) { ) {
if (data.model.engine !== 'nitro') { return } if (data.model.engine !== 'nitro') { return }
const timestamp = Date.now(); const timestamp = Date.now();
const message: ThreadMessage = { const message: ThreadMessage = {
id: ulid(), id: ulid(),

View File

@ -50,7 +50,6 @@ export default function useSendChatMessage() {
const [queuedMessage, setQueuedMessage] = useState(false) const [queuedMessage, setQueuedMessage] = useState(false)
const modelRef = useRef<Model | undefined>() const modelRef = useRef<Model | undefined>()
useEffect(() => { useEffect(() => {
modelRef.current = activeModel modelRef.current = activeModel
}, [activeModel]) }, [activeModel])
@ -91,19 +90,35 @@ export default function useSendChatMessage() {
id: ulid(), id: ulid(),
messages: messages, messages: messages,
threadId: activeThread.id, threadId: activeThread.id,
model: activeThread.assistants[0].model??selectedModel, model: activeThread.assistants[0].model ?? selectedModel,
} }
const modelId = selectedModel?.id ?? activeThread.assistants[0].model.id const modelId = selectedModel?.id ?? activeThread.assistants[0].model.id
if (activeModel?.id !== modelId) { if (activeModel?.id !== modelId) {
setQueuedMessage(true) setQueuedMessage(true)
await startModel(modelId) startModel(modelId)
await WaitForModelStarting(modelId)
setQueuedMessage(false) setQueuedMessage(false)
} }
events.emit(EventName.OnMessageSent, messageRequest) events.emit(EventName.OnMessageSent, messageRequest)
} }
// Blocks until the model we just asked to start becomes the active model.
// startModel() only kicks off loading; this helper polls modelRef (kept in
// sync with activeModel via the useEffect above) every 200 ms until the ids
// match, so the queued message is only emitted once loading has completed.
// NOTE(review): still waits indefinitely if the model never starts — consider
// a max-wait so setQueuedMessage(false) is always reached. TODO: Refactor @louis
const WaitForModelStarting = async (modelId: string): Promise<void> => {
  // Iterative polling replaces the original recursive setTimeout chain:
  // same 200 ms cadence and exit condition, but no unbounded promise-chain
  // buildup, and we return immediately when the model is already active
  // (the old version always paid at least one 200 ms delay).
  while (modelRef.current?.id !== modelId) {
    await new Promise<void>((resolve) => setTimeout(resolve, 200))
  }
}
const sendChatMessage = async () => { const sendChatMessage = async () => {
if (!currentPrompt || currentPrompt.trim().length === 0) { if (!currentPrompt || currentPrompt.trim().length === 0) {
return return
@ -180,8 +195,7 @@ export default function useSendChatMessage() {
id: msgId, id: msgId,
threadId: activeThread.id, threadId: activeThread.id,
messages, messages,
parameters: activeThread.assistants[0].model.parameters, model: selectedModel ?? activeThread.assistants[0].model,
model: selectedModel??activeThread.assistants[0].model,
} }
const timestamp = Date.now() const timestamp = Date.now()
const threadMessage: ThreadMessage = { const threadMessage: ThreadMessage = {
@ -213,9 +227,11 @@ export default function useSendChatMessage() {
if (activeModel?.id !== modelId) { if (activeModel?.id !== modelId) {
setQueuedMessage(true) setQueuedMessage(true)
await startModel(modelId) startModel(modelId)
await WaitForModelStarting(modelId)
setQueuedMessage(false) setQueuedMessage(false)
} }
console.log('messageRequest', messageRequest)
events.emit(EventName.OnMessageSent, messageRequest) events.emit(EventName.OnMessageSent, messageRequest)
} }