Allow pagination when querying Hugging Face (HF) models

Signed-off-by: James <james@jan.ai>
This commit is contained in:
James 2023-10-05 01:09:57 -07:00 committed by Louis
parent 1054c79eb7
commit dbe2cc848e
9 changed files with 8073 additions and 30 deletions

View File

@ -3,6 +3,9 @@ const { readdirSync, lstatSync } = require("fs");
const { app } = require("electron");
const { listModels, listFiles, fileDownloadInfo } = require("@huggingface/hub");
let modelsIterator = undefined;
let currentSearchOwner = undefined;
const ALL_MODELS = [
{
id: "llama-2-7b-chat.Q4_K_M.gguf.bin",
@ -88,28 +91,55 @@ function getDownloadedModels() {
return downloadedModels;
}
const searchHfModels = async (params) => {
const result = [];
const getNextModels = async (count) => {
const models = [];
let hasMore = true;
var index = 0;
while (models.length < count) {
const next = await modelsIterator.next();
for await (const model of listModels({
search: params.search,
credentials: params.credentials,
})) {
// end if we reached the end
if (next.done) {
hasMore = false;
break;
}
const model = next.value;
const files = await listFilesByName(model.name);
result.push({
models.push({
...model,
files,
});
index++;
if (index === params.limit) break;
}
const result = {
data: models,
hasMore,
};
return result;
};
const searchHfModels = async (params) => {
if (currentSearchOwner === params.search.owner && modelsIterator != null) {
// paginated search
console.debug(`Paginated search owner: ${params.search.owner}`);
const models = await getNextModels(params.limit);
return models;
} else {
// new search
console.debug(`Init new search owner: ${params.search.owner}`);
currentSearchOwner = params.search.owner;
modelsIterator = listModels({
search: params.search,
credentials: params.credentials,
});
const models = await getNextModels(params.limit);
return models;
}
};
const listFilesByName = async (modelName) => {
const repo = { type: "model", name: modelName };
const fileDownloadInfoMap = {};

View File

@ -54,6 +54,7 @@
"electron-store": "^8.1.0",
"electron-updater": "^6.1.4",
"pacote": "^17.0.4",
"react-intersection-observer": "^9.5.2",
"request": "^2.88.2",
"request-progress": "^3.0.0"
},

View File

@ -1,6 +1,6 @@
"use client";
import React, { useCallback, useRef, useState } from "react";
import React, { useCallback, useRef, useState, useEffect } from "react";
import ChatItem from "../ChatItem";
import { ChatMessage } from "@/_models/ChatMessage";
import useChatMessages from "@/_hooks/useChatMessages";
@ -16,8 +16,8 @@ const ChatBody: React.FC = () => {
const messageList = useAtomValue(
selectAtom(
chatMessages,
useCallback((v) => v[activeConversationId], [activeConversationId])
)
useCallback((v) => v[activeConversationId], [activeConversationId]),
),
);
const [content, setContent] = useState<React.JSX.Element[]>([]);
@ -40,10 +40,10 @@ const ChatBody: React.FC = () => {
if (message) intersectObs.current.observe(message);
},
[loading, hasMore]
[loading, hasMore],
);
React.useEffect(() => {
useEffect(() => {
const list = messageList?.map((message, index) => {
if (messageList?.length === index + 1) {
return (

View File

@ -2,7 +2,7 @@
import ExploreModelItemHeader from "../ExploreModelItemHeader";
import ModelVersionList from "../ModelVersionList";
import { Fragment, useState } from "react";
import { Fragment, forwardRef, useState } from "react";
import SimpleTag, { TagType } from "../SimpleTag";
import { displayDate } from "@/_utils/datetime";
import { Product } from "@/_models/Product";
@ -11,11 +11,16 @@ type Props = {
model: Product;
};
const ExploreModelItem: React.FC<Props> = ({ model }) => {
export type Ref = HTMLDivElement;
const ExploreModelItem = forwardRef<Ref, Props>(({ model }, ref) => {
const [show, setShow] = useState(false);
return (
<div className="flex flex-col border border-gray-200 rounded-md mb-4">
<div
ref={ref}
className="flex flex-col border border-gray-200 rounded-md mb-4"
>
<ExploreModelItemHeader
name={model.name}
status={TagType.Recommended}
@ -94,6 +99,6 @@ const ExploreModelItem: React.FC<Props> = ({ model }) => {
)}
</div>
);
};
});
export default ExploreModelItem;

View File

@ -3,19 +3,35 @@ import ExploreModelItem from "../ExploreModelItem";
import { modelSearchAtom } from "@/_helpers/JotaiWrapper";
import useGetHuggingFaceModel from "@/_hooks/useGetHuggingFaceModel";
import { useAtomValue } from "jotai";
import { useInView } from "react-intersection-observer";
const ExploreModelList: React.FC = () => {
const modelSearch = useAtomValue(modelSearchAtom);
const { modelList, getHuggingFaceModel } = useGetHuggingFaceModel();
const { ref, inView } = useInView({
threshold: 0,
triggerOnce: true,
});
useEffect(() => {
getHuggingFaceModel(modelSearch);
}, [modelSearch]);
useEffect(() => {
if (inView) {
console.debug("Load more models..");
getHuggingFaceModel(modelSearch);
}
}, [inView]);
return (
<div className="flex-1 overflow-y-auto scroll">
{modelList.map((item) => (
<ExploreModelItem key={item.id} model={item} />
{modelList.map((item, index) => (
<ExploreModelItem
ref={index === modelList.length - 1 ? ref : null}
key={item.id}
model={item}
/>
))}
</div>
);

View File

@ -17,7 +17,6 @@ const ModelDownloadingRow: React.FC<Props> = ({ downloadState }) => (
>
<td className="flex flex-col whitespace-nowrap px-6 py-4 text-sm font-medium text-gray-900">
{downloadState.fileName}
<span className="text-gray-500 font-normal">model.version</span>
</td>
<td className="whitespace-nowrap px-6 py-4 text-sm text-gray-500">
{toGigabytes(downloadState.size.transferred)}

View File

@ -3,7 +3,6 @@ import { useEffect, useState } from "react";
import { executeSerial } from "../../../electron/core/plugin-manager/execution/extension-manager";
import { DataService, ModelManagementService } from "../../shared/coreService";
import { SearchModelParamHf } from "@/_models/hf/SearchModelParam.hf";
import { ListModelOutputHf } from "@/_models/hf/ListModelOutput.hf";
export function useGetDownloadedModels() {
const [downloadedModels, setDownloadedModels] = useState<Product[]>([]);
@ -33,13 +32,13 @@ export async function getModelFiles(): Promise<Product[]> {
export async function searchHfModels(
params: SearchModelParamHf
): Promise<Product[]> {
const result: ListModelOutputHf[] = await executeSerial(
): Promise<QueryProductResult> {
const result = await executeSerial(
ModelManagementService.SEARCH_HF_MODELS,
params
);
const products: Product[] = result.map((model) => {
const products: Product[] = result.data.map((model: any) => {
const modelVersions: ModelVersion[] = [];
for (const [, file] of Object.entries(model.files)) {
@ -78,5 +77,14 @@ export async function searchHfModels(
return p;
});
return products;
return {
data: products,
hasMore: result.hasMore,
};
}
// TODO define somewhere else
export type QueryProductResult = {
data: Product[];
hasMore: boolean;
};

View File

@ -5,6 +5,9 @@ import { Product } from "@/_models/Product";
export default function useGetHuggingFaceModel() {
const [modelList, setModelList] = useState<Product[]>([]);
const [currentOwner, setCurrentOwner] = useState<string | undefined>(
undefined
);
const getHuggingFaceModel = async (owner?: string) => {
if (!owner) {
@ -14,11 +17,16 @@ export default function useGetHuggingFaceModel() {
const searchParams: SearchModelParamHf = {
search: { owner },
limit: 10,
limit: 5,
};
const result = await searchHfModels(searchParams);
console.debug("result", JSON.stringify(result));
setModelList(result);
if (owner !== currentOwner) {
setModelList(result.data);
setCurrentOwner(owner);
} else {
setModelList([...modelList, ...result.data]);
}
};
return { modelList, getHuggingFaceModel };

7976
yarn.lock Normal file

File diff suppressed because it is too large Load Diff