Merge pull request #409 from janhq/393-nitro-new-version-integration-inference-params-update-config-in-runtime
feat: Support for nitro release 0.1.2
This commit is contained in:
commit
28f8b8e032
1
plugins/inference-plugin/@types/global.d.ts
vendored
1
plugins/inference-plugin/@types/global.d.ts
vendored
@ -1,2 +1,3 @@
|
||||
declare const PLUGIN_NAME: string;
|
||||
declare const MODULE_PATH: string;
|
||||
declare const INFERENCE_URL: string;
|
||||
|
||||
@ -1,17 +1,23 @@
|
||||
import { EventName, InferenceService, NewMessageRequest, PluginService, core, events, store } from "@janhq/core";
|
||||
import {
|
||||
EventName,
|
||||
InferenceService,
|
||||
NewMessageRequest,
|
||||
PluginService,
|
||||
events,
|
||||
store,
|
||||
invokePluginFunc,
|
||||
} from "@janhq/core";
|
||||
import { Observable } from "rxjs";
|
||||
|
||||
const inferenceUrl = "http://localhost:3928/llama/chat_completion";
|
||||
|
||||
const initModel = async (product) => core.invokePluginFunc(MODULE_PATH, "initModel", product);
|
||||
const initModel = async (product) => invokePluginFunc(MODULE_PATH, "initModel", product);
|
||||
|
||||
const stopModel = () => {
|
||||
core.invokePluginFunc(MODULE_PATH, "killSubprocess");
|
||||
invokePluginFunc(MODULE_PATH, "killSubprocess");
|
||||
};
|
||||
|
||||
function requestInference(recentMessages: any[]): Observable<string> {
|
||||
return new Observable((subscriber) => {
|
||||
fetch(inferenceUrl, {
|
||||
fetch(INFERENCE_URL, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
|
||||
@ -5,8 +5,8 @@ const fs = require("fs");
|
||||
const tcpPortUsed = require("tcp-port-used");
|
||||
const { killPortProcess } = require("kill-port-process");
|
||||
|
||||
let subprocess = null;
|
||||
const PORT = 3928;
|
||||
let subprocess = null;
|
||||
|
||||
const initModel = (fileName) => {
|
||||
return (
|
||||
@ -15,9 +15,7 @@ const initModel = (fileName) => {
|
||||
reject("Model not found, please download again.");
|
||||
}
|
||||
if (subprocess) {
|
||||
console.error(
|
||||
"A subprocess is already running. Attempt to kill then reinit."
|
||||
);
|
||||
console.error("A subprocess is already running. Attempt to kill then reinit.");
|
||||
killSubprocess();
|
||||
}
|
||||
resolve(fileName);
|
||||
@ -32,35 +30,13 @@ const initModel = (fileName) => {
|
||||
// Spawn Nitro subprocess to load model
|
||||
.then(() => {
|
||||
let binaryFolder = path.join(__dirname, "nitro"); // Current directory by default
|
||||
|
||||
// Read the existing config
|
||||
const configFilePath = path.join(binaryFolder, "config", "config.json");
|
||||
let config: any = {};
|
||||
if (fs.existsSync(configFilePath)) {
|
||||
const rawData = fs.readFileSync(configFilePath, "utf-8");
|
||||
config = JSON.parse(rawData);
|
||||
}
|
||||
|
||||
// Update the llama_model_path
|
||||
if (!config.custom_config) {
|
||||
config.custom_config = {};
|
||||
}
|
||||
|
||||
const modelPath = path.join(app.getPath("userData"), fileName);
|
||||
|
||||
config.custom_config.llama_model_path = modelPath;
|
||||
|
||||
// Write the updated config back to the file
|
||||
fs.writeFileSync(configFilePath, JSON.stringify(config, null, 4));
|
||||
|
||||
let binaryName;
|
||||
|
||||
if (process.platform === "win32") {
|
||||
binaryName = "nitro_windows_amd64.exe";
|
||||
binaryName = "nitro_windows_amd64_cuda.exe";
|
||||
} else if (process.platform === "darwin") {
|
||||
// Mac OS platform
|
||||
binaryName =
|
||||
process.arch === "arm64" ? "nitro_mac_arm64" : "nitro_mac_amd64";
|
||||
binaryName = process.arch === "arm64" ? "nitro_mac_arm64" : "nitro_mac_intel";
|
||||
} else {
|
||||
// Linux
|
||||
binaryName = "nitro_linux_amd64_cuda"; // For other platforms
|
||||
@ -69,8 +45,7 @@ const initModel = (fileName) => {
|
||||
const binaryPath = path.join(binaryFolder, binaryName);
|
||||
|
||||
// Execute the binary
|
||||
|
||||
subprocess = spawn(binaryPath, [configFilePath], { cwd: binaryFolder });
|
||||
subprocess = spawn(binaryPath, { cwd: binaryFolder });
|
||||
|
||||
// Handle subprocess output
|
||||
subprocess.stdout.on("data", (data) => {
|
||||
@ -88,7 +63,29 @@ const initModel = (fileName) => {
|
||||
})
|
||||
.then(() => tcpPortUsed.waitUntilUsed(PORT, 300, 30000))
|
||||
.then(() => {
|
||||
return {};
|
||||
const llama_model_path = path.join(app.getPath("userData"), fileName);
|
||||
|
||||
const config = {
|
||||
llama_model_path,
|
||||
ctx_len: 2048,
|
||||
ngl: 100,
|
||||
embedding: true, // Always enable embedding mode on
|
||||
};
|
||||
|
||||
// Load model config
|
||||
return fetch(`http://127.0.0.1:${PORT}/inferences/llamacpp/loadmodel`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(config),
|
||||
});
|
||||
})
|
||||
.then((res) => {
|
||||
if (res.ok) {
|
||||
return {};
|
||||
}
|
||||
throw new Error("Nitro: Model failed to load.");
|
||||
})
|
||||
.catch((err) => {
|
||||
return { error: err };
|
||||
|
||||
@ -1,13 +0,0 @@
|
||||
{
|
||||
"listeners": [
|
||||
{
|
||||
"address": "0.0.0.0",
|
||||
"port": 3928
|
||||
}
|
||||
],
|
||||
"custom_config": {
|
||||
"llama_model_path": "",
|
||||
"ctx_len": 2048,
|
||||
"ngl": 100
|
||||
}
|
||||
}
|
||||
@ -13,8 +13,8 @@ typedef struct {
|
||||
|
||||
#define QK4_1 32
|
||||
typedef struct {
|
||||
half d; // delta
|
||||
half m; // min
|
||||
half d; // delta
|
||||
half m; // min
|
||||
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
||||
} block_q4_1;
|
||||
|
||||
@ -132,6 +132,13 @@ kernel void kernel_relu(
|
||||
dst[tpig] = max(0.0f, src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sqr(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * src0[tpig];
|
||||
}
|
||||
|
||||
constant float GELU_COEF_A = 0.044715f;
|
||||
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
@ -338,10 +345,11 @@ kernel void kernel_rms_norm(
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
||||
device const float * x_scalar = (device const float *) x;
|
||||
float4 sumf=0;
|
||||
float all_sum=0;
|
||||
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
||||
device const float * x_scalar = (device const float *) x;
|
||||
|
||||
float4 sumf = 0;
|
||||
float all_sum = 0;
|
||||
|
||||
// parallel sum
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
@ -354,6 +362,7 @@ kernel void kernel_rms_norm(
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// broadcast, simd group number is ntg / 32
|
||||
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
||||
if (tpitg < i) {
|
||||
@ -361,7 +370,9 @@ kernel void kernel_rms_norm(
|
||||
}
|
||||
}
|
||||
if (tpitg == 0) {
|
||||
for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
|
||||
for (int i = 4 * (ne00 / 4); i < ne00; i++) {
|
||||
sum[0] += x_scalar[i];
|
||||
}
|
||||
sum[0] /= ne00;
|
||||
}
|
||||
|
||||
@ -376,7 +387,9 @@ kernel void kernel_rms_norm(
|
||||
y[i00] = x[i00] * scale;
|
||||
}
|
||||
if (tpitg == 0) {
|
||||
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
|
||||
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
|
||||
y_scalar[i00] = x_scalar[i00] * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -416,8 +429,8 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
|
||||
}
|
||||
|
||||
// putting them in the kernel cause a significant performance penalty
|
||||
#define N_DST 4 // each SIMD group works on 4 rows
|
||||
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
||||
#define N_DST 4 // each SIMD group works on 4 rows
|
||||
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
||||
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
||||
//Note: This is a template, but strictly speaking it only applies to
|
||||
// quantizations where the block size is 32. It also does not
|
||||
@ -428,18 +441,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
||||
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
|
||||
uint3 tgpig, uint tiisg, uint sgitg) {
|
||||
const int nb = ne00/QK4_0;
|
||||
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
|
||||
const int first_row = (r0 * nsg + sgitg) * nr;
|
||||
|
||||
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
|
||||
|
||||
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
float yl[16]; // src1 vector cache
|
||||
float sumf[nr]={0.f};
|
||||
|
||||
const int ix = tiisg/2;
|
||||
const int il = 8*(tiisg%2);
|
||||
float yl[16]; // src1 vector cache
|
||||
float sumf[nr] = {0.f};
|
||||
|
||||
const int ix = (tiisg/2);
|
||||
const int il = (tiisg%2)*8;
|
||||
|
||||
device const float * yb = y + ix * QK4_0 + il;
|
||||
|
||||
@ -450,6 +468,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
||||
sumy += yb[i] + yb[i+1];
|
||||
yl[i+0] = yb[i+ 0];
|
||||
yl[i+1] = yb[i+ 1]/256.f;
|
||||
|
||||
sumy += yb[i+16] + yb[i+17];
|
||||
yl[i+8] = yb[i+16]/16.f;
|
||||
yl[i+9] = yb[i+17]/4096.f;
|
||||
@ -465,12 +484,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
||||
for (int row = 0; row < nr; ++row) {
|
||||
const float tot = simd_sum(sumf[row]);
|
||||
if (tiisg == 0 && first_row + row < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
||||
dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mat_q4_0_f32(
|
||||
kernel void kernel_mul_mv_q4_0_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -483,12 +502,12 @@ kernel void kernel_mul_mat_q4_0_f32(
|
||||
constant int64_t & ne1[[buffer(16)]],
|
||||
constant uint & gqa[[buffer(17)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mat_q4_1_f32(
|
||||
kernel void kernel_mul_mv_q4_1_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -508,7 +527,7 @@ kernel void kernel_mul_mat_q4_1_f32(
|
||||
|
||||
#define NB_Q8_0 8
|
||||
|
||||
kernel void kernel_mul_mat_q8_0_f32(
|
||||
kernel void kernel_mul_mv_q8_0_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -572,7 +591,7 @@ kernel void kernel_mul_mat_q8_0_f32(
|
||||
|
||||
#define N_F32_F32 4
|
||||
|
||||
kernel void kernel_mul_mat_f32_f32(
|
||||
kernel void kernel_mul_mv_f32_f32(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
@ -643,7 +662,7 @@ kernel void kernel_mul_mat_f32_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mat_f16_f32_1row(
|
||||
kernel void kernel_mul_mv_f16_f32_1row(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
@ -662,7 +681,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
|
||||
const int64_t r0 = tgpig.x;
|
||||
const int64_t r1 = tgpig.y;
|
||||
@ -697,7 +716,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
|
||||
|
||||
#define N_F16_F32 4
|
||||
|
||||
kernel void kernel_mul_mat_f16_f32(
|
||||
kernel void kernel_mul_mv_f16_f32(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
@ -769,7 +788,7 @@ kernel void kernel_mul_mat_f16_f32(
|
||||
}
|
||||
|
||||
// Assumes row size (ne00) is a multiple of 4
|
||||
kernel void kernel_mul_mat_f16_f32_l4(
|
||||
kernel void kernel_mul_mv_f16_f32_l4(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
@ -1098,6 +1117,62 @@ kernel void kernel_cpy_f32_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_concat(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant uint64_t & nb13,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t i03 = tgpig.z;
|
||||
const int64_t i02 = tgpig.y;
|
||||
const int64_t i01 = tgpig.x;
|
||||
|
||||
const int64_t i13 = i03 % ne13;
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i11 = i01 % ne11;
|
||||
|
||||
device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
|
||||
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
||||
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
||||
if (i02 < ne02) {
|
||||
((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
|
||||
src0_ptr += ntg.x*nb00;
|
||||
} else {
|
||||
((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
|
||||
src1_ptr += ntg.x*nb10;
|
||||
}
|
||||
dst_ptr += ntg.x*nb0;
|
||||
}
|
||||
}
|
||||
|
||||
//============================================ k-quants ======================================================
|
||||
|
||||
#ifndef QK_K
|
||||
@ -1190,7 +1265,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
||||
|
||||
//====================================== dot products =========================
|
||||
|
||||
kernel void kernel_mul_mat_q2_K_f32(
|
||||
kernel void kernel_mul_mv_q2_K_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -1334,7 +1409,7 @@ kernel void kernel_mul_mat_q2_K_f32(
|
||||
}
|
||||
|
||||
#if QK_K == 256
|
||||
kernel void kernel_mul_mat_q3_K_f32(
|
||||
kernel void kernel_mul_mv_q3_K_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -1486,7 +1561,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||
}
|
||||
}
|
||||
#else
|
||||
kernel void kernel_mul_mat_q3_K_f32(
|
||||
kernel void kernel_mul_mv_q3_K_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -1557,7 +1632,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||
#endif
|
||||
|
||||
#if QK_K == 256
|
||||
kernel void kernel_mul_mat_q4_K_f32(
|
||||
kernel void kernel_mul_mv_q4_K_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -1663,7 +1738,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
||||
}
|
||||
}
|
||||
#else
|
||||
kernel void kernel_mul_mat_q4_K_f32(
|
||||
kernel void kernel_mul_mv_q4_K_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -1752,7 +1827,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
||||
}
|
||||
#endif
|
||||
|
||||
kernel void kernel_mul_mat_q5_K_f32(
|
||||
kernel void kernel_mul_mv_q5_K_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -1925,7 +2000,7 @@ kernel void kernel_mul_mat_q5_K_f32(
|
||||
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mat_q6_K_f32(
|
||||
kernel void kernel_mul_mv_q6_K_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -2263,7 +2338,7 @@ kernel void kernel_get_rows(
|
||||
}
|
||||
|
||||
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
||||
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
|
||||
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
||||
#define BLOCK_SIZE_K 32
|
||||
#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
|
||||
#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
|
||||
@ -2300,9 +2375,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||||
const uint r0 = tgpig.y;
|
||||
const uint r1 = tgpig.x;
|
||||
const uint im = tgpig.z;
|
||||
|
||||
// if this block is of 64x32 shape or smaller
|
||||
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
||||
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
||||
|
||||
// a thread shouldn't load data outside of the matrix
|
||||
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
||||
@ -2326,26 +2403,30 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||||
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||
|
||||
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
||||
//load data and store to threadgroup memory
|
||||
// load data and store to threadgroup memory
|
||||
half4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
#pragma unroll(16)
|
||||
for (int i = 0; i < 16; i++) {
|
||||
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
||||
+ 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
|
||||
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
||||
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
||||
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
||||
}
|
||||
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
|
||||
= *((device float2x4 *)y);
|
||||
|
||||
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
|
||||
|
||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
||||
y += BLOCK_SIZE_K;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
//load matrices from threadgroup memory and conduct outer products
|
||||
|
||||
// load matrices from threadgroup memory and conduct outer products
|
||||
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
||||
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
||||
#pragma unroll(4)
|
||||
@ -2360,6 +2441,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||||
|
||||
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
||||
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
||||
|
||||
#pragma unroll(8)
|
||||
for (int i = 0; i < 8; i++){
|
||||
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
||||
@ -2368,25 +2450,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||||
}
|
||||
|
||||
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
|
||||
device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
|
||||
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
|
||||
device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
|
||||
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
|
||||
}
|
||||
} else {
|
||||
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
|
||||
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
|
||||
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
|
||||
if (sgitg==0) {
|
||||
|
||||
device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
|
||||
if (sgitg == 0) {
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
|
||||
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
||||
*(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
|
||||
}
|
||||
}
|
||||
@ -2440,4 +2523,4 @@ template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
||||
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
plugins/inference-plugin/nitro/nitro_windows_amd64_cuda.exe
Normal file
BIN
plugins/inference-plugin/nitro/nitro_windows_amd64_cuda.exe
Normal file
Binary file not shown.
@ -19,6 +19,7 @@ module.exports = {
|
||||
new webpack.DefinePlugin({
|
||||
PLUGIN_NAME: JSON.stringify(packageJson.name),
|
||||
MODULE_PATH: JSON.stringify(`${packageJson.name}/${packageJson.module}`),
|
||||
INFERENCE_URL: JSON.stringify("http://127.0.0.1:3928/inferences/llamacpp/chat_completion"),
|
||||
}),
|
||||
],
|
||||
output: {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user