allow pagination for querying hf
Signed-off-by: James <james@jan.ai>
This commit is contained in:
parent
1054c79eb7
commit
dbe2cc848e
@ -3,6 +3,9 @@ const { readdirSync, lstatSync } = require("fs");
|
|||||||
const { app } = require("electron");
|
const { app } = require("electron");
|
||||||
const { listModels, listFiles, fileDownloadInfo } = require("@huggingface/hub");
|
const { listModels, listFiles, fileDownloadInfo } = require("@huggingface/hub");
|
||||||
|
|
||||||
|
let modelsIterator = undefined;
|
||||||
|
let currentSearchOwner = undefined;
|
||||||
|
|
||||||
const ALL_MODELS = [
|
const ALL_MODELS = [
|
||||||
{
|
{
|
||||||
id: "llama-2-7b-chat.Q4_K_M.gguf.bin",
|
id: "llama-2-7b-chat.Q4_K_M.gguf.bin",
|
||||||
@ -88,28 +91,55 @@ function getDownloadedModels() {
|
|||||||
return downloadedModels;
|
return downloadedModels;
|
||||||
}
|
}
|
||||||
|
|
||||||
const searchHfModels = async (params) => {
|
const getNextModels = async (count) => {
|
||||||
const result = [];
|
const models = [];
|
||||||
|
let hasMore = true;
|
||||||
|
|
||||||
var index = 0;
|
while (models.length < count) {
|
||||||
|
const next = await modelsIterator.next();
|
||||||
|
|
||||||
for await (const model of listModels({
|
// end if we reached the end
|
||||||
search: params.search,
|
if (next.done) {
|
||||||
credentials: params.credentials,
|
hasMore = false;
|
||||||
})) {
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const model = next.value;
|
||||||
const files = await listFilesByName(model.name);
|
const files = await listFilesByName(model.name);
|
||||||
result.push({
|
|
||||||
|
models.push({
|
||||||
...model,
|
...model,
|
||||||
files,
|
files,
|
||||||
});
|
});
|
||||||
|
|
||||||
index++;
|
|
||||||
if (index === params.limit) break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const result = {
|
||||||
|
data: models,
|
||||||
|
hasMore,
|
||||||
|
};
|
||||||
return result;
|
return result;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const searchHfModels = async (params) => {
|
||||||
|
if (currentSearchOwner === params.search.owner && modelsIterator != null) {
|
||||||
|
// paginated search
|
||||||
|
console.debug(`Paginated search owner: ${params.search.owner}`);
|
||||||
|
const models = await getNextModels(params.limit);
|
||||||
|
return models;
|
||||||
|
} else {
|
||||||
|
// new search
|
||||||
|
console.debug(`Init new search owner: ${params.search.owner}`);
|
||||||
|
currentSearchOwner = params.search.owner;
|
||||||
|
modelsIterator = listModels({
|
||||||
|
search: params.search,
|
||||||
|
credentials: params.credentials,
|
||||||
|
});
|
||||||
|
|
||||||
|
const models = await getNextModels(params.limit);
|
||||||
|
return models;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const listFilesByName = async (modelName) => {
|
const listFilesByName = async (modelName) => {
|
||||||
const repo = { type: "model", name: modelName };
|
const repo = { type: "model", name: modelName };
|
||||||
const fileDownloadInfoMap = {};
|
const fileDownloadInfoMap = {};
|
||||||
|
|||||||
@ -54,6 +54,7 @@
|
|||||||
"electron-store": "^8.1.0",
|
"electron-store": "^8.1.0",
|
||||||
"electron-updater": "^6.1.4",
|
"electron-updater": "^6.1.4",
|
||||||
"pacote": "^17.0.4",
|
"pacote": "^17.0.4",
|
||||||
|
"react-intersection-observer": "^9.5.2",
|
||||||
"request": "^2.88.2",
|
"request": "^2.88.2",
|
||||||
"request-progress": "^3.0.0"
|
"request-progress": "^3.0.0"
|
||||||
},
|
},
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import React, { useCallback, useRef, useState } from "react";
|
import React, { useCallback, useRef, useState, useEffect } from "react";
|
||||||
import ChatItem from "../ChatItem";
|
import ChatItem from "../ChatItem";
|
||||||
import { ChatMessage } from "@/_models/ChatMessage";
|
import { ChatMessage } from "@/_models/ChatMessage";
|
||||||
import useChatMessages from "@/_hooks/useChatMessages";
|
import useChatMessages from "@/_hooks/useChatMessages";
|
||||||
@ -16,8 +16,8 @@ const ChatBody: React.FC = () => {
|
|||||||
const messageList = useAtomValue(
|
const messageList = useAtomValue(
|
||||||
selectAtom(
|
selectAtom(
|
||||||
chatMessages,
|
chatMessages,
|
||||||
useCallback((v) => v[activeConversationId], [activeConversationId])
|
useCallback((v) => v[activeConversationId], [activeConversationId]),
|
||||||
)
|
),
|
||||||
);
|
);
|
||||||
const [content, setContent] = useState<React.JSX.Element[]>([]);
|
const [content, setContent] = useState<React.JSX.Element[]>([]);
|
||||||
|
|
||||||
@ -40,10 +40,10 @@ const ChatBody: React.FC = () => {
|
|||||||
|
|
||||||
if (message) intersectObs.current.observe(message);
|
if (message) intersectObs.current.observe(message);
|
||||||
},
|
},
|
||||||
[loading, hasMore]
|
[loading, hasMore],
|
||||||
);
|
);
|
||||||
|
|
||||||
React.useEffect(() => {
|
useEffect(() => {
|
||||||
const list = messageList?.map((message, index) => {
|
const list = messageList?.map((message, index) => {
|
||||||
if (messageList?.length === index + 1) {
|
if (messageList?.length === index + 1) {
|
||||||
return (
|
return (
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import ExploreModelItemHeader from "../ExploreModelItemHeader";
|
import ExploreModelItemHeader from "../ExploreModelItemHeader";
|
||||||
import ModelVersionList from "../ModelVersionList";
|
import ModelVersionList from "../ModelVersionList";
|
||||||
import { Fragment, useState } from "react";
|
import { Fragment, forwardRef, useState } from "react";
|
||||||
import SimpleTag, { TagType } from "../SimpleTag";
|
import SimpleTag, { TagType } from "../SimpleTag";
|
||||||
import { displayDate } from "@/_utils/datetime";
|
import { displayDate } from "@/_utils/datetime";
|
||||||
import { Product } from "@/_models/Product";
|
import { Product } from "@/_models/Product";
|
||||||
@ -11,11 +11,16 @@ type Props = {
|
|||||||
model: Product;
|
model: Product;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ExploreModelItem: React.FC<Props> = ({ model }) => {
|
export type Ref = HTMLDivElement;
|
||||||
|
|
||||||
|
const ExploreModelItem = forwardRef<Ref, Props>(({ model }, ref) => {
|
||||||
const [show, setShow] = useState(false);
|
const [show, setShow] = useState(false);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col border border-gray-200 rounded-md mb-4">
|
<div
|
||||||
|
ref={ref}
|
||||||
|
className="flex flex-col border border-gray-200 rounded-md mb-4"
|
||||||
|
>
|
||||||
<ExploreModelItemHeader
|
<ExploreModelItemHeader
|
||||||
name={model.name}
|
name={model.name}
|
||||||
status={TagType.Recommended}
|
status={TagType.Recommended}
|
||||||
@ -94,6 +99,6 @@ const ExploreModelItem: React.FC<Props> = ({ model }) => {
|
|||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
export default ExploreModelItem;
|
export default ExploreModelItem;
|
||||||
|
|||||||
@ -3,19 +3,35 @@ import ExploreModelItem from "../ExploreModelItem";
|
|||||||
import { modelSearchAtom } from "@/_helpers/JotaiWrapper";
|
import { modelSearchAtom } from "@/_helpers/JotaiWrapper";
|
||||||
import useGetHuggingFaceModel from "@/_hooks/useGetHuggingFaceModel";
|
import useGetHuggingFaceModel from "@/_hooks/useGetHuggingFaceModel";
|
||||||
import { useAtomValue } from "jotai";
|
import { useAtomValue } from "jotai";
|
||||||
|
import { useInView } from "react-intersection-observer";
|
||||||
|
|
||||||
const ExploreModelList: React.FC = () => {
|
const ExploreModelList: React.FC = () => {
|
||||||
const modelSearch = useAtomValue(modelSearchAtom);
|
const modelSearch = useAtomValue(modelSearchAtom);
|
||||||
const { modelList, getHuggingFaceModel } = useGetHuggingFaceModel();
|
const { modelList, getHuggingFaceModel } = useGetHuggingFaceModel();
|
||||||
|
const { ref, inView } = useInView({
|
||||||
|
threshold: 0,
|
||||||
|
triggerOnce: true,
|
||||||
|
});
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
getHuggingFaceModel(modelSearch);
|
getHuggingFaceModel(modelSearch);
|
||||||
}, [modelSearch]);
|
}, [modelSearch]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (inView) {
|
||||||
|
console.debug("Load more models..");
|
||||||
|
getHuggingFaceModel(modelSearch);
|
||||||
|
}
|
||||||
|
}, [inView]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex-1 overflow-y-auto scroll">
|
<div className="flex-1 overflow-y-auto scroll">
|
||||||
{modelList.map((item) => (
|
{modelList.map((item, index) => (
|
||||||
<ExploreModelItem key={item.id} model={item} />
|
<ExploreModelItem
|
||||||
|
ref={index === modelList.length - 1 ? ref : null}
|
||||||
|
key={item.id}
|
||||||
|
model={item}
|
||||||
|
/>
|
||||||
))}
|
))}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|||||||
@ -17,7 +17,6 @@ const ModelDownloadingRow: React.FC<Props> = ({ downloadState }) => (
|
|||||||
>
|
>
|
||||||
<td className="flex flex-col whitespace-nowrap px-6 py-4 text-sm font-medium text-gray-900">
|
<td className="flex flex-col whitespace-nowrap px-6 py-4 text-sm font-medium text-gray-900">
|
||||||
{downloadState.fileName}
|
{downloadState.fileName}
|
||||||
<span className="text-gray-500 font-normal">model.version</span>
|
|
||||||
</td>
|
</td>
|
||||||
<td className="whitespace-nowrap px-6 py-4 text-sm text-gray-500">
|
<td className="whitespace-nowrap px-6 py-4 text-sm text-gray-500">
|
||||||
{toGigabytes(downloadState.size.transferred)}
|
{toGigabytes(downloadState.size.transferred)}
|
||||||
|
|||||||
@ -3,7 +3,6 @@ import { useEffect, useState } from "react";
|
|||||||
import { executeSerial } from "../../../electron/core/plugin-manager/execution/extension-manager";
|
import { executeSerial } from "../../../electron/core/plugin-manager/execution/extension-manager";
|
||||||
import { DataService, ModelManagementService } from "../../shared/coreService";
|
import { DataService, ModelManagementService } from "../../shared/coreService";
|
||||||
import { SearchModelParamHf } from "@/_models/hf/SearchModelParam.hf";
|
import { SearchModelParamHf } from "@/_models/hf/SearchModelParam.hf";
|
||||||
import { ListModelOutputHf } from "@/_models/hf/ListModelOutput.hf";
|
|
||||||
|
|
||||||
export function useGetDownloadedModels() {
|
export function useGetDownloadedModels() {
|
||||||
const [downloadedModels, setDownloadedModels] = useState<Product[]>([]);
|
const [downloadedModels, setDownloadedModels] = useState<Product[]>([]);
|
||||||
@ -33,13 +32,13 @@ export async function getModelFiles(): Promise<Product[]> {
|
|||||||
|
|
||||||
export async function searchHfModels(
|
export async function searchHfModels(
|
||||||
params: SearchModelParamHf
|
params: SearchModelParamHf
|
||||||
): Promise<Product[]> {
|
): Promise<QueryProductResult> {
|
||||||
const result: ListModelOutputHf[] = await executeSerial(
|
const result = await executeSerial(
|
||||||
ModelManagementService.SEARCH_HF_MODELS,
|
ModelManagementService.SEARCH_HF_MODELS,
|
||||||
params
|
params
|
||||||
);
|
);
|
||||||
|
|
||||||
const products: Product[] = result.map((model) => {
|
const products: Product[] = result.data.map((model: any) => {
|
||||||
const modelVersions: ModelVersion[] = [];
|
const modelVersions: ModelVersion[] = [];
|
||||||
|
|
||||||
for (const [, file] of Object.entries(model.files)) {
|
for (const [, file] of Object.entries(model.files)) {
|
||||||
@ -78,5 +77,14 @@ export async function searchHfModels(
|
|||||||
return p;
|
return p;
|
||||||
});
|
});
|
||||||
|
|
||||||
return products;
|
return {
|
||||||
|
data: products,
|
||||||
|
hasMore: result.hasMore,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO define somewhere else
|
||||||
|
export type QueryProductResult = {
|
||||||
|
data: Product[];
|
||||||
|
hasMore: boolean;
|
||||||
|
};
|
||||||
|
|||||||
@ -5,6 +5,9 @@ import { Product } from "@/_models/Product";
|
|||||||
|
|
||||||
export default function useGetHuggingFaceModel() {
|
export default function useGetHuggingFaceModel() {
|
||||||
const [modelList, setModelList] = useState<Product[]>([]);
|
const [modelList, setModelList] = useState<Product[]>([]);
|
||||||
|
const [currentOwner, setCurrentOwner] = useState<string | undefined>(
|
||||||
|
undefined
|
||||||
|
);
|
||||||
|
|
||||||
const getHuggingFaceModel = async (owner?: string) => {
|
const getHuggingFaceModel = async (owner?: string) => {
|
||||||
if (!owner) {
|
if (!owner) {
|
||||||
@ -14,11 +17,16 @@ export default function useGetHuggingFaceModel() {
|
|||||||
|
|
||||||
const searchParams: SearchModelParamHf = {
|
const searchParams: SearchModelParamHf = {
|
||||||
search: { owner },
|
search: { owner },
|
||||||
limit: 10,
|
limit: 5,
|
||||||
};
|
};
|
||||||
const result = await searchHfModels(searchParams);
|
const result = await searchHfModels(searchParams);
|
||||||
console.debug("result", JSON.stringify(result));
|
console.debug("result", JSON.stringify(result));
|
||||||
setModelList(result);
|
if (owner !== currentOwner) {
|
||||||
|
setModelList(result.data);
|
||||||
|
setCurrentOwner(owner);
|
||||||
|
} else {
|
||||||
|
setModelList([...modelList, ...result.data]);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
return { modelList, getHuggingFaceModel };
|
return { modelList, getHuggingFaceModel };
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user