allow pagination for querying hf
Signed-off-by: James <james@jan.ai>
This commit is contained in:
parent
1054c79eb7
commit
dbe2cc848e
@ -3,6 +3,9 @@ const { readdirSync, lstatSync } = require("fs");
|
||||
const { app } = require("electron");
|
||||
const { listModels, listFiles, fileDownloadInfo } = require("@huggingface/hub");
|
||||
|
||||
let modelsIterator = undefined;
|
||||
let currentSearchOwner = undefined;
|
||||
|
||||
const ALL_MODELS = [
|
||||
{
|
||||
id: "llama-2-7b-chat.Q4_K_M.gguf.bin",
|
||||
@ -88,28 +91,55 @@ function getDownloadedModels() {
|
||||
return downloadedModels;
|
||||
}
|
||||
|
||||
const searchHfModels = async (params) => {
|
||||
const result = [];
|
||||
const getNextModels = async (count) => {
|
||||
const models = [];
|
||||
let hasMore = true;
|
||||
|
||||
var index = 0;
|
||||
while (models.length < count) {
|
||||
const next = await modelsIterator.next();
|
||||
|
||||
for await (const model of listModels({
|
||||
search: params.search,
|
||||
credentials: params.credentials,
|
||||
})) {
|
||||
// end if we reached the end
|
||||
if (next.done) {
|
||||
hasMore = false;
|
||||
break;
|
||||
}
|
||||
|
||||
const model = next.value;
|
||||
const files = await listFilesByName(model.name);
|
||||
result.push({
|
||||
|
||||
models.push({
|
||||
...model,
|
||||
files,
|
||||
});
|
||||
|
||||
index++;
|
||||
if (index === params.limit) break;
|
||||
}
|
||||
|
||||
const result = {
|
||||
data: models,
|
||||
hasMore,
|
||||
};
|
||||
return result;
|
||||
};
|
||||
|
||||
const searchHfModels = async (params) => {
|
||||
if (currentSearchOwner === params.search.owner && modelsIterator != null) {
|
||||
// paginated search
|
||||
console.debug(`Paginated search owner: ${params.search.owner}`);
|
||||
const models = await getNextModels(params.limit);
|
||||
return models;
|
||||
} else {
|
||||
// new search
|
||||
console.debug(`Init new search owner: ${params.search.owner}`);
|
||||
currentSearchOwner = params.search.owner;
|
||||
modelsIterator = listModels({
|
||||
search: params.search,
|
||||
credentials: params.credentials,
|
||||
});
|
||||
|
||||
const models = await getNextModels(params.limit);
|
||||
return models;
|
||||
}
|
||||
};
|
||||
|
||||
const listFilesByName = async (modelName) => {
|
||||
const repo = { type: "model", name: modelName };
|
||||
const fileDownloadInfoMap = {};
|
||||
|
||||
@ -54,6 +54,7 @@
|
||||
"electron-store": "^8.1.0",
|
||||
"electron-updater": "^6.1.4",
|
||||
"pacote": "^17.0.4",
|
||||
"react-intersection-observer": "^9.5.2",
|
||||
"request": "^2.88.2",
|
||||
"request-progress": "^3.0.0"
|
||||
},
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import React, { useCallback, useRef, useState } from "react";
|
||||
import React, { useCallback, useRef, useState, useEffect } from "react";
|
||||
import ChatItem from "../ChatItem";
|
||||
import { ChatMessage } from "@/_models/ChatMessage";
|
||||
import useChatMessages from "@/_hooks/useChatMessages";
|
||||
@ -16,8 +16,8 @@ const ChatBody: React.FC = () => {
|
||||
const messageList = useAtomValue(
|
||||
selectAtom(
|
||||
chatMessages,
|
||||
useCallback((v) => v[activeConversationId], [activeConversationId])
|
||||
)
|
||||
useCallback((v) => v[activeConversationId], [activeConversationId]),
|
||||
),
|
||||
);
|
||||
const [content, setContent] = useState<React.JSX.Element[]>([]);
|
||||
|
||||
@ -40,10 +40,10 @@ const ChatBody: React.FC = () => {
|
||||
|
||||
if (message) intersectObs.current.observe(message);
|
||||
},
|
||||
[loading, hasMore]
|
||||
[loading, hasMore],
|
||||
);
|
||||
|
||||
React.useEffect(() => {
|
||||
useEffect(() => {
|
||||
const list = messageList?.map((message, index) => {
|
||||
if (messageList?.length === index + 1) {
|
||||
return (
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
import ExploreModelItemHeader from "../ExploreModelItemHeader";
|
||||
import ModelVersionList from "../ModelVersionList";
|
||||
import { Fragment, useState } from "react";
|
||||
import { Fragment, forwardRef, useState } from "react";
|
||||
import SimpleTag, { TagType } from "../SimpleTag";
|
||||
import { displayDate } from "@/_utils/datetime";
|
||||
import { Product } from "@/_models/Product";
|
||||
@ -11,11 +11,16 @@ type Props = {
|
||||
model: Product;
|
||||
};
|
||||
|
||||
const ExploreModelItem: React.FC<Props> = ({ model }) => {
|
||||
export type Ref = HTMLDivElement;
|
||||
|
||||
const ExploreModelItem = forwardRef<Ref, Props>(({ model }, ref) => {
|
||||
const [show, setShow] = useState(false);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col border border-gray-200 rounded-md mb-4">
|
||||
<div
|
||||
ref={ref}
|
||||
className="flex flex-col border border-gray-200 rounded-md mb-4"
|
||||
>
|
||||
<ExploreModelItemHeader
|
||||
name={model.name}
|
||||
status={TagType.Recommended}
|
||||
@ -94,6 +99,6 @@ const ExploreModelItem: React.FC<Props> = ({ model }) => {
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
export default ExploreModelItem;
|
||||
|
||||
@ -3,19 +3,35 @@ import ExploreModelItem from "../ExploreModelItem";
|
||||
import { modelSearchAtom } from "@/_helpers/JotaiWrapper";
|
||||
import useGetHuggingFaceModel from "@/_hooks/useGetHuggingFaceModel";
|
||||
import { useAtomValue } from "jotai";
|
||||
import { useInView } from "react-intersection-observer";
|
||||
|
||||
const ExploreModelList: React.FC = () => {
|
||||
const modelSearch = useAtomValue(modelSearchAtom);
|
||||
const { modelList, getHuggingFaceModel } = useGetHuggingFaceModel();
|
||||
const { ref, inView } = useInView({
|
||||
threshold: 0,
|
||||
triggerOnce: true,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
getHuggingFaceModel(modelSearch);
|
||||
}, [modelSearch]);
|
||||
|
||||
useEffect(() => {
|
||||
if (inView) {
|
||||
console.debug("Load more models..");
|
||||
getHuggingFaceModel(modelSearch);
|
||||
}
|
||||
}, [inView]);
|
||||
|
||||
return (
|
||||
<div className="flex-1 overflow-y-auto scroll">
|
||||
{modelList.map((item) => (
|
||||
<ExploreModelItem key={item.id} model={item} />
|
||||
{modelList.map((item, index) => (
|
||||
<ExploreModelItem
|
||||
ref={index === modelList.length - 1 ? ref : null}
|
||||
key={item.id}
|
||||
model={item}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
|
||||
@ -17,7 +17,6 @@ const ModelDownloadingRow: React.FC<Props> = ({ downloadState }) => (
|
||||
>
|
||||
<td className="flex flex-col whitespace-nowrap px-6 py-4 text-sm font-medium text-gray-900">
|
||||
{downloadState.fileName}
|
||||
<span className="text-gray-500 font-normal">model.version</span>
|
||||
</td>
|
||||
<td className="whitespace-nowrap px-6 py-4 text-sm text-gray-500">
|
||||
{toGigabytes(downloadState.size.transferred)}
|
||||
|
||||
@ -3,7 +3,6 @@ import { useEffect, useState } from "react";
|
||||
import { executeSerial } from "../../../electron/core/plugin-manager/execution/extension-manager";
|
||||
import { DataService, ModelManagementService } from "../../shared/coreService";
|
||||
import { SearchModelParamHf } from "@/_models/hf/SearchModelParam.hf";
|
||||
import { ListModelOutputHf } from "@/_models/hf/ListModelOutput.hf";
|
||||
|
||||
export function useGetDownloadedModels() {
|
||||
const [downloadedModels, setDownloadedModels] = useState<Product[]>([]);
|
||||
@ -33,13 +32,13 @@ export async function getModelFiles(): Promise<Product[]> {
|
||||
|
||||
export async function searchHfModels(
|
||||
params: SearchModelParamHf
|
||||
): Promise<Product[]> {
|
||||
const result: ListModelOutputHf[] = await executeSerial(
|
||||
): Promise<QueryProductResult> {
|
||||
const result = await executeSerial(
|
||||
ModelManagementService.SEARCH_HF_MODELS,
|
||||
params
|
||||
);
|
||||
|
||||
const products: Product[] = result.map((model) => {
|
||||
const products: Product[] = result.data.map((model: any) => {
|
||||
const modelVersions: ModelVersion[] = [];
|
||||
|
||||
for (const [, file] of Object.entries(model.files)) {
|
||||
@ -78,5 +77,14 @@ export async function searchHfModels(
|
||||
return p;
|
||||
});
|
||||
|
||||
return products;
|
||||
return {
|
||||
data: products,
|
||||
hasMore: result.hasMore,
|
||||
};
|
||||
}
|
||||
|
||||
// TODO define somewhere else
|
||||
export type QueryProductResult = {
|
||||
data: Product[];
|
||||
hasMore: boolean;
|
||||
};
|
||||
|
||||
@ -5,6 +5,9 @@ import { Product } from "@/_models/Product";
|
||||
|
||||
export default function useGetHuggingFaceModel() {
|
||||
const [modelList, setModelList] = useState<Product[]>([]);
|
||||
const [currentOwner, setCurrentOwner] = useState<string | undefined>(
|
||||
undefined
|
||||
);
|
||||
|
||||
const getHuggingFaceModel = async (owner?: string) => {
|
||||
if (!owner) {
|
||||
@ -14,11 +17,16 @@ export default function useGetHuggingFaceModel() {
|
||||
|
||||
const searchParams: SearchModelParamHf = {
|
||||
search: { owner },
|
||||
limit: 10,
|
||||
limit: 5,
|
||||
};
|
||||
const result = await searchHfModels(searchParams);
|
||||
console.debug("result", JSON.stringify(result));
|
||||
setModelList(result);
|
||||
if (owner !== currentOwner) {
|
||||
setModelList(result.data);
|
||||
setCurrentOwner(owner);
|
||||
} else {
|
||||
setModelList([...modelList, ...result.data]);
|
||||
}
|
||||
};
|
||||
|
||||
return { modelList, getHuggingFaceModel };
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user