implement local Nomic Embed via llama.cpp (#2086)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
pull/2053/head
Jared Van Bortel 2 months ago committed by GitHub
parent 171f4e488e
commit 406e88b59a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -97,11 +97,6 @@ foreach(BUILD_VARIANT IN LISTS BUILD_VARIANTS)
add_library(gptj-${BUILD_VARIANT} SHARED
gptj.cpp utils.h utils.cpp llmodel_shared.cpp llmodel_shared.h)
prepare_target(gptj llama-mainline)
add_library(bert-${BUILD_VARIANT} SHARED
bert.cpp utils.h utils.cpp llmodel_shared.cpp llmodel_shared.h)
target_compile_definitions(bert-${BUILD_VARIANT} PRIVATE LLAMA_VERSIONS=>=3 LLAMA_DATE=999999)
prepare_target(bert llama-mainline)
endif()
endforeach()

@ -1,910 +0,0 @@
#define BERT_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#include "bert_impl.h"
#include "llmodel_shared.h"
#include "ggml.h"
#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <map>
#include <string>
#include <vector>
#include <iostream>
#include <regex>
#include <thread>
#include <algorithm>
#include <numeric>
//#define DEBUG_BERT
namespace {
const char *modelType_ = "Bert";
}
typedef int32_t bert_vocab_id;
// default hparams (all-MiniLM-L6-v2)
struct bert_hparams
{
int32_t n_vocab = 30522;
int32_t n_max_tokens = 512;
int32_t n_embd = 256;
int32_t n_intermediate = 1536;
int32_t n_head = 12;
int32_t n_layer = 6;
};
struct bert_layer
{
// normalization
struct ggml_tensor *ln_att_w;
struct ggml_tensor *ln_att_b;
struct ggml_tensor *ln_out_w;
struct ggml_tensor *ln_out_b;
// attention
struct ggml_tensor *q_w;
struct ggml_tensor *q_b;
struct ggml_tensor *k_w;
struct ggml_tensor *k_b;
struct ggml_tensor *v_w;
struct ggml_tensor *v_b;
struct ggml_tensor *o_w;
struct ggml_tensor *o_b;
// ff
struct ggml_tensor *ff_i_w;
struct ggml_tensor *ff_i_b;
struct ggml_tensor *ff_o_w;
struct ggml_tensor *ff_o_b;
};
struct bert_vocab
{
std::map<std::string, bert_vocab_id> token_to_id;
std::map<std::string, bert_vocab_id> subword_token_to_id;
std::map<bert_vocab_id, std::string> _id_to_token;
std::map<bert_vocab_id, std::string> _id_to_subword_token;
};
struct bert_model
{
bert_hparams hparams;
// embeddings weights
struct ggml_tensor *word_embeddings;
struct ggml_tensor *token_type_embeddings;
struct ggml_tensor *position_embeddings;
struct ggml_tensor *ln_e_w;
struct ggml_tensor *ln_e_b;
std::vector<bert_layer> layers;
struct ggml_context *ctx;
};
// Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
struct bert_ctx
{
bert_model model;
bert_vocab vocab;
size_t mem_per_token;
int64_t mem_per_input;
int32_t max_batch_n;
llm_buffer buf_compute;
llm_buffer work_buf;
};
int32_t bert_n_embd(bert_ctx * ctx)
{
return ctx->model.hparams.n_embd;
}
int32_t bert_n_max_tokens(bert_ctx * ctx)
{
return ctx->model.hparams.n_max_tokens;
}
const char* bert_vocab_id_to_token(bert_ctx * ctx, bert_vocab_id id) {
bert_vocab & vocab = ctx->vocab;
auto it = vocab._id_to_token.find(id);
if (it != vocab._id_to_token.end())
{
return it->second.c_str();
}
it = vocab._id_to_subword_token.find(id);
if (it != vocab._id_to_subword_token.end())
{
return it->second.c_str();
}
return "[UNK TOKEN from bert_vocab]";
}
//
// Tokenizing
//
static size_t utf8_len(char src)
{
const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4};
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
return lookup[highbits];
}
std::string stripAccents(const std::string &inputString)
{
std::string resultString;
std::map<std::string, char> accentMap = {{"À", 'A'},{"Á", 'A'},
{"Â", 'A'},{"Ã", 'A'},{"Ä", 'A'},{"Å", 'A'},{"à", 'a'},{"á", 'a'},
{"â", 'a'},{"ã", 'a'},{"ä", 'a'},{"å", 'a'},{"È", 'E'},{"É", 'E'},
{"Ê", 'E'},{"Ë", 'E'},{"è", 'e'},{"é", 'e'},{"ê", 'e'},{"ë", 'e'},
{"Ì", 'I'},{"Í", 'I'},{"Î", 'I'},{"Ï", 'I'},{"ì", 'i'},{"í", 'i'},
{"î", 'i'},{"ï", 'i'},{"Ò", 'O'},{"Ó", 'O'},{"Ô", 'O'},{"Õ", 'O'},
{"Ö", 'O'},{"ò", 'o'},{"ó", 'o'},{"ô", 'o'},{"õ", 'o'},{"ö", 'o'},
{"Ù", 'U'},{"Ú", 'U'},{"Û", 'U'},{"Ü", 'U'},{"ù", 'u'},{"ú", 'u'},
{"û", 'u'},{"ü", 'u'},{"Ý", 'Y'},{"ý", 'y'},{"Ç", 'C'},{"ç", 'c'},
{"Ñ", 'N'},{"ñ", 'n'},
};
for (size_t i = 0; i < inputString.length();)
{
int len = utf8_len(inputString[i]);
std::string curChar = inputString.substr(i, len);
auto iter = accentMap.find(curChar);
if (iter != accentMap.end())
{
resultString += iter->second;
}
else
{
resultString += curChar;
}
i += len;
}
return resultString;
}
std::string bert_normalize_prompt(const std::string &text)
{
// TODO: handle chinese characters? https://github.com/huggingface/tokenizers/blob/ef5f50605ddf9f8caef1598c0e4853862b9707a7/tokenizers/src/normalizers/bert.rs#L98
std::string text2 = stripAccents(text);
for (size_t i = 0; i < text2.size(); i += utf8_len(text2[i]))
{
char c = text2[i];
if (c >= 'A' && c <= 'Z')
text2[i] = c - 'A' + 'a';
}
return text2;
}
std::vector<bert_vocab_id> bert_tokenize(
struct bert_ctx * ctx,
const char * text)
{
const bert_vocab &vocab = ctx->vocab;
std::string str = text;
std::vector<std::string> words;
// first split the text into words
{
str = bert_normalize_prompt(str);
std::string pat = R"([[:punct:]]|[[:alpha:]]+|[[:digit:]]+)";
std::regex re(pat);
std::smatch m;
while (std::regex_search(str, m, re))
{
for (std::string x : m)
{
words.push_back(x);
}
str = m.suffix();
}
}
// find the longest tokens that form the words:
std::vector<bert_vocab_id> tokens;
int cls_tok_id = 101;
tokens.push_back(cls_tok_id);
for (const auto &word : words)
{
if (word.size() == 0)
continue;
int i = 0;
int n = word.size();
auto *token_map = &vocab.token_to_id;
while (i < n)
{
int j = n;
while (j > i)
{
auto it = token_map->find(word.substr(i, j - i));
if (it != token_map->end())
{
tokens.push_back(it->second);
i = j;
token_map = &vocab.subword_token_to_id;
}
--j;
}
if (j == i)
{
fprintf(stderr, "%s: unknown token '%s'\n", __func__, word.substr(i, 1).data());
token_map = &vocab.subword_token_to_id;
++i;
}
}
}
return tokens;
}
void bert_resize_ctx(bert_ctx * ctx, int32_t new_size) {
int64_t buf_size_new = ctx->mem_per_input * new_size;
// TODO: Max memory should be a param? Now just 1 GB
int64_t GB = 1 << 30;
#if defined(DEBUG_BERT)
printf("%s: requested_buf_size %lldMB\n", __func__, buf_size_new / (1 << 20));
#endif
if (buf_size_new > GB) {
int32_t adjusted_new_size = GB / ctx->mem_per_input;
if (adjusted_new_size < 1) adjusted_new_size = 1;
#if defined(DEBUG_BERT)
printf("%s: requested batch size %d, actual new batch size %d\n", __func__, new_size, adjusted_new_size);
#endif
new_size = adjusted_new_size;
buf_size_new = ctx->mem_per_input * new_size;
}
if (new_size > ctx->max_batch_n) {
ctx->buf_compute.resize(buf_size_new);
ctx->max_batch_n = new_size;
}
}
void bert_eval(
struct bert_ctx *ctx,
int32_t n_threads,
const bert_vocab_id *raw_tokens,
int32_t n_tokens,
float *embeddings)
{
const bert_model& model = ctx->model;
bool mem_req_mode = !embeddings;
// batch_embeddings is nullptr for the initial memory requirements run
if (!mem_req_mode && 1 > ctx->max_batch_n)
bert_resize_ctx(ctx, 1);
const int N = n_tokens;
const auto &tokens = raw_tokens;
const auto &hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_max_tokens = hparams.n_max_tokens;
const int n_head = hparams.n_head;
const int d_head = n_embd / n_head;
std::vector<float> result;
if (N > n_max_tokens)
{
fprintf(stderr, "Too many tokens, maximum is %d\n", n_max_tokens);
return;
}
auto & mem_per_token = ctx->mem_per_token;
auto & buf_compute = ctx->buf_compute;
struct ggml_init_params params = {
.mem_size = buf_compute.size,
.mem_buffer = buf_compute.addr,
.no_alloc = false,
};
struct ggml_context *ctx0 = ggml_init(params);
struct ggml_cgraph *gf = ggml_new_graph(ctx0);
// Embeddings. word_embeddings + token_type_embeddings + position_embeddings
struct ggml_tensor *token_layer = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(token_layer->data, tokens, N * ggml_element_size(token_layer));
struct ggml_tensor *token_types = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
ggml_set_zero(token_types);
struct ggml_tensor *positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
for (int i = 0; i < N; i++)
{
ggml_set_i32_1d(positions, i, i);
}
struct ggml_tensor *inpL = ggml_get_rows(ctx0, model.word_embeddings, token_layer);
inpL = ggml_add(ctx0,
ggml_get_rows(ctx0, model.token_type_embeddings, token_types),
inpL);
inpL = ggml_add(ctx0,
ggml_get_rows(ctx0, model.position_embeddings, positions),
inpL);
// embd norm
{
inpL = ggml_norm(ctx0, inpL, 1e-12f);
inpL = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.ln_e_w, inpL),
inpL),
ggml_repeat(ctx0, model.ln_e_b, inpL));
}
// layers
for (int il = 0; il < n_layer; il++)
{
struct ggml_tensor *cur = inpL;
// self-attention
{
struct ggml_tensor *Qcur = cur;
Qcur = ggml_reshape_3d(ctx0,
ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, Qcur),
ggml_mul_mat(ctx0, model.layers[il].q_w, Qcur)),
d_head, n_head, N);
struct ggml_tensor *Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
struct ggml_tensor *Kcur = cur;
Kcur = ggml_reshape_3d(ctx0,
ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].k_b, Kcur),
ggml_mul_mat(ctx0, model.layers[il].k_w, Kcur)),
d_head, n_head, N);
struct ggml_tensor *K = ggml_permute(ctx0, Kcur, 0, 2, 1, 3);
struct ggml_tensor *Vcur = cur;
Vcur = ggml_reshape_3d(ctx0,
ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].v_b, Vcur),
ggml_mul_mat(ctx0, model.layers[il].v_w, Vcur)),
d_head, n_head, N);
struct ggml_tensor *V = ggml_permute(ctx0, Vcur, 0, 2, 1, 3);
struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q);
// KQ = soft_max(KQ / sqrt(head width))
KQ = ggml_soft_max(
ctx0, ggml_scale(ctx0, KQ, 1.0f / sqrt((float)d_head))
);
V = ggml_cont(ctx0, ggml_transpose(ctx0, V));
struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V, KQ);
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
cur = ggml_cpy(ctx0,
KQV,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
}
// attention output
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].o_b, cur),
ggml_mul_mat(ctx0, model.layers[il].o_w, cur));
// re-add the layer input
cur = ggml_add(ctx0, cur, inpL);
// attention norm
{
cur = ggml_norm(ctx0, cur, 1e-12f);
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_att_w, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_att_b, cur));
}
struct ggml_tensor *att_output = cur;
// intermediate_output = self.intermediate(attention_output)
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].ff_i_b, cur),
cur);
cur = ggml_gelu(ctx0, cur);
// layer_output = self.output(intermediate_output, attention_output)
cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].ff_o_b, cur),
cur);
// attentions bypass the intermediate layer
cur = ggml_add(ctx0, att_output, cur);
// output norm
{
cur = ggml_norm(ctx0, cur, 1e-12f);
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_out_w, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_out_b, cur));
}
inpL = cur;
}
inpL = ggml_cont(ctx0, ggml_transpose(ctx0, inpL));
// pooler
struct ggml_tensor *sum = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, 1);
ggml_set_f32(sum, 1.0f / N);
inpL = ggml_mul_mat(ctx0, inpL, sum);
ggml_tensor *output = inpL;
// run the computation
ggml_build_forward_expand(gf, output);
//ggml_graph_compute_g4a()
ggml_graph_compute_g4a(ctx->work_buf, gf, n_threads);
//ggml_graph_compute(ctx0, gf);
// float *dat = ggml_get_data_f32(output);
// pretty_print_tensor(dat, output->ne, output->nb, output->n_dims - 1, "");
#ifdef GGML_PERF
// print timing information per ggml operation (for debugging purposes)
// requires GGML_PERF to be defined
ggml_graph_print(gf);
#endif
if (!mem_req_mode) {
memcpy(embeddings, (float *)ggml_get_data(output), sizeof(float) * n_embd);
} else {
mem_per_token = ggml_used_mem(ctx0) / N;
}
// printf("used_mem = %zu KB \n", ggml_used_mem(ctx0) / 1024);
// printf("mem_per_token = %zu KB \n", mem_per_token / 1024);
ggml_free(ctx0);
}
//
// Loading and setup
//
void bert_free(bert_ctx * ctx) {
delete ctx;
}
struct bert_ctx * bert_load_from_file(const char *fname)
{
#if defined(DEBUG_BERT)
printf("%s: loading model from '%s' - please wait ...\n", __func__, fname);
#endif
bert_ctx * new_bert = new bert_ctx;
bert_model & model = new_bert->model;
bert_vocab & vocab = new_bert->vocab;
struct gguf_init_params params = {
/*.no_alloc = */ false,
/*.ctx = */ &model.ctx,
};
gguf_context *ggufctx = gguf_init_from_file(fname, params);
if (!ggufctx) {
fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
return nullptr;
}
printf("%s: gguf version = %d\n", __func__, gguf_get_version(ggufctx));
printf("%s: gguf alignment = %zu\n", __func__, gguf_get_alignment(ggufctx));
printf("%s: gguf data offset = %zu\n", __func__, gguf_get_data_offset(ggufctx));
// print some standard metadata
{
int keyidx;
keyidx = gguf_find_key(ggufctx, "general.name");
if (keyidx != -1) { printf("%s: model name = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); }
keyidx = gguf_find_key(ggufctx, "general.description");
if (keyidx != -1) { printf("%s: model description = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); }
keyidx = gguf_find_key(ggufctx, "general.author");
if (keyidx != -1) { printf("%s: model author = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); }
keyidx = gguf_find_key(ggufctx, "general.license");
if (keyidx != -1) { printf("%s: model license = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); }
keyidx = gguf_find_key(ggufctx, "general.architecture");
if (keyidx != -1) { printf("%s: model architecture = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); }
keyidx = gguf_find_key(ggufctx, "general.file_type");
if (keyidx != -1) { printf("%s: model file type = %" PRIu32 "\n", __func__, gguf_get_val_u32(ggufctx, keyidx)); }
keyidx = gguf_find_key(ggufctx, "gptneox.tensor_data_layout");
if (keyidx != -1) { printf("%s: model data layout = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); }
keyidx = gguf_find_key(ggufctx, "general.source.huggingface.repository");
if (keyidx != -1) { printf("%s: model source HF repo = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); }
}
// check required metadata
{
// check model architecture kv
int keyidx = gguf_find_key(ggufctx, "general.architecture");
if (keyidx == -1) {
fprintf(stderr, "%s: gguf model architecture not found!\n", __func__);
return nullptr;
}
if (strcmp(gguf_get_val_str(ggufctx, keyidx), "bert") != 0) {
fprintf(stderr, "%s: model architecture not supported!\n", __func__);
return nullptr;
}
}
// load hparams
{
auto &hparams = model.hparams;
bool ok = false;
int keyidx;
do {
keyidx = gguf_find_key(ggufctx, "bert.context_length");
if (keyidx == -1) { break; }
hparams.n_max_tokens = gguf_get_val_u32(ggufctx, keyidx);
keyidx = gguf_find_key(ggufctx, "bert.embedding_length");
if (keyidx == -1) { break; }
hparams.n_embd = gguf_get_val_u32(ggufctx, keyidx);
keyidx = gguf_find_key(ggufctx, "bert.feed_forward_length");
if (keyidx == -1) { break; }
hparams.n_intermediate = gguf_get_val_u32(ggufctx, keyidx);
keyidx = gguf_find_key(ggufctx, "bert.attention.head_count");
if (keyidx == -1) { break; }
hparams.n_head = gguf_get_val_u32(ggufctx, keyidx);
keyidx = gguf_find_key(ggufctx, "bert.block_count");
if (keyidx == -1) { break; }
hparams.n_layer = gguf_get_val_u32(ggufctx, keyidx);
ok = true;
} while (false);
if (!ok) {
fprintf(stderr, "%s: required hparam missing!\n", __func__);
return nullptr;
}
#if defined(DEBUG_BERT)
printf("%s: n_max_tokens = %d\n", __func__, hparams.n_max_tokens);
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
printf("%s: n_intermediate = %d\n", __func__, hparams.n_intermediate);
printf("%s: n_head = %d\n", __func__, hparams.n_head);
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
#endif
}
// load vocab
{
auto & hparams = model.hparams;
int keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.model");
if (keyidx == -1) {
fprintf(stderr, "%s: tokenizer model not found!\n", __func__);
return nullptr;
}
if (strcmp(gguf_get_val_str(ggufctx, keyidx), "bert") != 0) {
fprintf(stderr, "%s: tokenizer model not supported!\n", __func__);
return nullptr;
}
int tokens_keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.tokens");
if (tokens_keyidx == -1) {
fprintf(stderr, "%s: bert tokenizer vocab not found!\n", __func__);
return nullptr;
}
hparams.n_vocab = gguf_get_arr_n(ggufctx, tokens_keyidx);
printf("%s: bert tokenizer vocab = %d\n", __func__, int(hparams.n_vocab));
for (int i = 0; i < hparams.n_vocab; i++) {
std::string word = gguf_get_arr_str(ggufctx, tokens_keyidx, i);
if (word[0] == '#' && word[1] == '#')
{
vocab.subword_token_to_id[word.substr(2)] = i;
vocab._id_to_subword_token[i] = word;
}
if (vocab.token_to_id.count(word) == 0)
{
vocab.token_to_id[word] = i;
vocab._id_to_token[i] = word;
}
}
}
auto &ctx = model.ctx;
#if defined(DEBUG_BERT)
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ggml_get_mem_size(ctx) / (1024.0 * 1024.0));
#endif
// prepare memory for the weights
{
const int n_layer = model.hparams.n_layer;
model.layers.resize(n_layer);
model.word_embeddings = ggml_get_tensor(ctx, "token_embd.weight");
model.token_type_embeddings = ggml_get_tensor(ctx, "token_types.weight");
model.position_embeddings = ggml_get_tensor(ctx, "position_embd.weight");
model.ln_e_w = ggml_get_tensor(ctx, "output_norm.weight");
model.ln_e_b = ggml_get_tensor(ctx, "output_norm.bias");
auto name = [](int i, std::string n) {
static std::string key;
key = "blk." + std::to_string(i) + "." + n;
return key.c_str();
};
for (int i = 0; i < n_layer; ++i)
{
auto &layer = model.layers[i];
layer.ln_att_w = ggml_get_tensor(ctx, name(i, "attn_norm.weight"));
layer.ln_att_b = ggml_get_tensor(ctx, name(i, "attn_norm.bias"));
layer.ln_out_w = ggml_get_tensor(ctx, name(i, "ffn_norm.weight"));
layer.ln_out_b = ggml_get_tensor(ctx, name(i, "ffn_norm.bias"));
layer.q_w = ggml_get_tensor(ctx, name(i, "attn_q.weight"));
layer.q_b = ggml_get_tensor(ctx, name(i, "attn_q.bias"));
layer.k_w = ggml_get_tensor(ctx, name(i, "attn_k.weight"));
layer.k_b = ggml_get_tensor(ctx, name(i, "attn_k.bias"));
layer.v_w = ggml_get_tensor(ctx, name(i, "attn_v.weight"));
layer.v_b = ggml_get_tensor(ctx, name(i, "attn_v.bias"));
layer.o_w = ggml_get_tensor(ctx, name(i, "attn_output.weight"));
layer.o_b = ggml_get_tensor(ctx, name(i, "attn_output.bias"));
layer.ff_i_w = ggml_get_tensor(ctx, name(i, "ffn_up.weight"));
layer.ff_i_b = ggml_get_tensor(ctx, name(i, "ffn_up.bias"));
layer.ff_o_w = ggml_get_tensor(ctx, name(i, "ffn_down.weight"));
layer.ff_o_b = ggml_get_tensor(ctx, name(i, "ffn_down.bias"));
}
}
// Calculate space requirements for setting up context buffers later
{
bert_vocab_id tokens[] = {0, 1, 2, 3};
// TODO: We set the initial buffer size to 16MB and hope it's enough. Maybe there is a better way to do this?
new_bert->buf_compute.resize(16 * 1024 * 1024);
bert_eval(new_bert, 1, tokens, 4, nullptr);
new_bert->max_batch_n = 0;
// TODO: Max tokens should be a param?
int32_t N = new_bert->model.hparams.n_max_tokens;
new_bert->mem_per_input = 2.2 * (new_bert->mem_per_token * N); // add 10% to account for ggml object overhead
}
#if defined(DEBUG_BERT)
printf("%s: mem_per_token %ld KB, mem_per_input %ld MB\n", __func__, new_bert->mem_per_token / (1 << 10), new_bert->mem_per_input / (1 << 20));
#endif
return new_bert;
}
struct BertPrivate {
const std::string modelPath;
bool modelLoaded;
bert_ctx *ctx = nullptr;
int64_t n_threads = 0;
};
Bert::Bert() : d_ptr(new BertPrivate) {
d_ptr->modelLoaded = false;
}
Bert::~Bert() {
bert_free(d_ptr->ctx);
}
bool Bert::loadModel(const std::string &modelPath, int n_ctx, int ngl)
{
(void)n_ctx;
(void)ngl;
d_ptr->modelLoaded = false;
auto * ctx = bert_load_from_file(modelPath.c_str());
fflush(stdout);
if (!ctx)
return false;
d_ptr->ctx = ctx;
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->modelLoaded = true;
return true;
}
bool Bert::isModelLoaded() const
{
return d_ptr->modelLoaded;
}
size_t Bert::requiredMem(const std::string &modelPath, int n_ctx, int ngl)
{
(void)modelPath;
(void)n_ctx;
(void)ngl;
return 0;
}
size_t Bert::stateSize() const
{
return 0;
}
size_t Bert::saveState(uint8_t */*dest*/) const
{
return 0;
}
size_t Bert::restoreState(const uint8_t */*src*/)
{
return 0;
}
void Bert::setThreadCount(int32_t n_threads)
{
d_ptr->n_threads = n_threads;
}
int32_t Bert::threadCount() const
{
return d_ptr->n_threads;
}
std::vector<float> Bert::embedding(const std::string &text)
{
const int overlap = 32;
const LLModel::Token clsToken = 101;
const size_t contextLength = bert_n_max_tokens(d_ptr->ctx);
typedef std::vector<LLModel::Token> TokenString;
TokenString tokens = ::bert_tokenize(d_ptr->ctx, text.c_str());
#if defined(DEBUG_BERT)
std::cerr << "embedding: " << tokens.size()
<< " contextLength " << contextLength
<< "\n";
#endif
std::vector<double> embeddingsSum(bert_n_embd(d_ptr->ctx), 0);
int embeddingsSumTotal = 0;
size_t start_pos = 0;
bool isFirstChunk = true;
while (start_pos < tokens.size()) {
TokenString chunk;
if (!isFirstChunk)
chunk.push_back(clsToken);
const size_t l = isFirstChunk ? contextLength : contextLength - 1;
if (tokens.size() - start_pos > l) {
chunk.insert(chunk.end(), tokens.begin() + start_pos, tokens.begin() + start_pos + l);
start_pos = start_pos + contextLength - overlap;
} else {
chunk.insert(chunk.end(), tokens.begin() + start_pos, tokens.end());
start_pos = tokens.size();
}
#if defined(DEBUG_BERT)
std::cerr << "chunk length: " << chunk.size()
<< " embeddingsSumTotal " << embeddingsSumTotal
<< " contextLength " << contextLength
<< " start_pos " << start_pos
<< "\n";
#endif
embeddingsSumTotal++;
std::vector<float> embeddings(bert_n_embd(d_ptr->ctx));
bert_eval(d_ptr->ctx, d_ptr->n_threads, chunk.data(), chunk.size(), embeddings.data());
std::transform(embeddingsSum.begin(), embeddingsSum.end(), embeddings.begin(), embeddingsSum.begin(), std::plus<float>());
isFirstChunk = false;
}
std::transform(embeddingsSum.begin(), embeddingsSum.end(), embeddingsSum.begin(), [embeddingsSumTotal](float num){ return num / embeddingsSumTotal; });
double magnitude = std::sqrt(std::inner_product(embeddingsSum.begin(), embeddingsSum.end(), embeddingsSum.begin(), 0.0));
for (auto &value : embeddingsSum)
value /= magnitude;
std::vector<float> finalEmbeddings(embeddingsSum.begin(), embeddingsSum.end());
return finalEmbeddings;
}
std::vector<LLModel::Token> Bert::tokenize(PromptContext &ctx, const std::string &str, bool special) const
{
(void)ctx;
(void)special;
return ::bert_tokenize(d_ptr->ctx, str.c_str());
}
LLModel::Token Bert::sampleToken(PromptContext &/*promptCtx*/) const
{
return 999 /*!*/;
}
std::string Bert::tokenToString(Token id) const
{
return bert_vocab_id_to_token(d_ptr->ctx, id);
}
bool Bert::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
{
std::vector<float> embeddings(bert_n_embd(d_ptr->ctx));
int32_t cls = 101;
const bool useCLS = tokens.front() != cls;
if (useCLS) {
std::vector<int32_t> myTokens;
myTokens.push_back(cls);
myTokens.insert(myTokens.end(), tokens.begin(), tokens.end());
bert_eval(d_ptr->ctx, d_ptr->n_threads, myTokens.data(), myTokens.size(), embeddings.data());
} else
bert_eval(d_ptr->ctx, d_ptr->n_threads, tokens.data(), tokens.size(), embeddings.data());
ctx.n_past = 0; // bert does not store any context
return true;
}
int32_t Bert::contextLength() const
{
return bert_n_max_tokens(d_ptr->ctx);
}
const std::vector<LLModel::Token> &Bert::endTokens() const
{
static const std::vector<LLModel::Token> out = { 102 /*sep*/};
return out;
}
std::string get_arch_name(gguf_context *ctx_gguf) {
std::string arch_name;
const int kid = gguf_find_key(ctx_gguf, "general.architecture");
enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
if (ktype != GGUF_TYPE_STRING) {
throw std::runtime_error("ERROR: Can't get general architecture from gguf file.");
}
return gguf_get_val_str(ctx_gguf, kid);
}
#if defined(_WIN32)
#define DLL_EXPORT __declspec(dllexport)
#else
#define DLL_EXPORT __attribute__ ((visibility ("default")))
#endif
extern "C" {
DLL_EXPORT bool is_g4a_backend_model_implementation() {
return true;
}
DLL_EXPORT const char *get_model_type() {
return modelType_;
}
DLL_EXPORT const char *get_build_variant() {
return GGML_BUILD_VARIANT;
}
DLL_EXPORT bool magic_match(const char * fname) {
struct ggml_context * ctx_meta = NULL;
struct gguf_init_params params = {
/*.no_alloc = */ true,
/*.ctx = */ &ctx_meta,
};
gguf_context *ctx_gguf = gguf_init_from_file(fname, params);
if (!ctx_gguf)
return false;
bool isValid = gguf_get_version(ctx_gguf) <= 3;
isValid = isValid && get_arch_name(ctx_gguf) == "bert";
gguf_free(ctx_gguf);
return isValid;
}
DLL_EXPORT LLModel *construct() {
return new Bert;
}
}

@ -1,45 +0,0 @@
#ifndef BERT_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#error This file is NOT meant to be included outside of bert.cpp. Doing so is DANGEROUS. Be sure to know what you are doing before proceeding to #define BERT_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#endif
#ifndef BERT_H
#define BERT_H
#include <string>
#include <functional>
#include <vector>
#include <memory>
#include "llmodel.h"
struct BertPrivate;
class Bert : public LLModel {
public:
Bert();
~Bert();
bool supportsEmbedding() const override { return true; }
bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath, int n_ctx, int ngl) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() const override;
std::vector<float> embedding(const std::string &text) override;
private:
std::unique_ptr<BertPrivate> d_ptr;
protected:
std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override;
Token sampleToken(PromptContext &ctx) const override;
std::string tokenToString(Token id) const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
int32_t contextLength() const override;
const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override { return true; }
};
#endif // BERT_H

@ -1 +1 @@
Subproject commit 2a086f71f5b570a0f047f88d88cf5704aae7ec7c
Subproject commit 43c20ce8004a4eac25ffe89e52bdf94bc7c47c02

@ -6,6 +6,7 @@
#include <cstdio>
#include <cstring>
#include <fstream>
#include <initializer_list>
#include <iomanip>
#include <iostream>
#include <map>
@ -30,6 +31,19 @@ static constexpr int GGUF_VER_MAX = 3;
static const char * const modelType_ = "LLaMA";
static const std::vector<const char *> KNOWN_ARCHES {
"baichuan", "bert", "bloom", "codeshell", "falcon", "gemma", "gpt2", "llama", "mpt", "nomic-bert", "orion",
"persimmon", "phi2", "plamo", "qwen", "qwen2", "refact", "stablelm", "starcoder"
};
static const std::vector<const char *> EMBEDDING_ARCHES {
"bert", "nomic-bert"
};
static bool is_embedding_arch(const std::string &arch) {
return std::find(EMBEDDING_ARCHES.begin(), EMBEDDING_ARCHES.end(), arch) < EMBEDDING_ARCHES.end();
}
static bool llama_verbose() {
const char* var = getenv("GPT4ALL_VERBOSE_LLAMACPP");
return var && *var;
@ -124,7 +138,7 @@ static int32_t get_arch_key_u32(std::string const &modelPath, std::string const
auto * ctx = load_gguf(modelPath.c_str());
if (!ctx)
return -1;
auto arch = get_arch_name(ctx);
std::string arch = get_arch_name(ctx);
int32_t value = -1;
if (ctx) {
@ -193,7 +207,7 @@ size_t LLamaModel::requiredMem(const std::string &modelPath, int n_ctx, int ngl)
return filesize + est_kvcache_size;
}
bool LLamaModel::isModelBlacklisted(const std::string &modelPath) {
bool LLamaModel::isModelBlacklisted(const std::string &modelPath) const {
auto * ctx = load_gguf(modelPath.c_str());
if (!ctx) {
std::cerr << __func__ << ": failed to load " << modelPath << "\n";
@ -229,6 +243,18 @@ bool LLamaModel::isModelBlacklisted(const std::string &modelPath) {
return res;
}
bool LLamaModel::isEmbeddingModel(const std::string &modelPath) const {
auto *ctx_gguf = load_gguf(modelPath.c_str());
if (!ctx_gguf) {
std::cerr << __func__ << ": failed to load GGUF from " << modelPath << "\n";
return false;
}
std::string arch = get_arch_name(ctx_gguf);
gguf_free(ctx_gguf);
return is_embedding_arch(arch);
}
bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
{
d_ptr->modelLoaded = false;
@ -287,20 +313,25 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
if (!d_ptr->model) {
fflush(stdout);
d_ptr->device = -1;
std::cerr << "LLAMA ERROR: failed to load model from " << modelPath << std::endl;
std::cerr << "LLAMA ERROR: failed to load model from " << modelPath << std::endl;
return false;
}
const int n_ctx_train = llama_n_ctx_train(d_ptr->model);
if (n_ctx > n_ctx_train) {
std::cerr << "warning: model was trained on only " << n_ctx_train << " context tokens ("
<< n_ctx << " specified)\n";
}
// -- initialize the context --
d_ptr->ctx_params = llama_context_default_params();
bool isEmbedding = is_embedding_arch(llama_model_arch(d_ptr->model));
const int n_ctx_train = llama_n_ctx_train(d_ptr->model);
if (isEmbedding) {
d_ptr->ctx_params.n_batch = n_ctx_train;
} else {
if (n_ctx > n_ctx_train) {
std::cerr << "warning: model was trained on only " << n_ctx_train << " context tokens ("
<< n_ctx << " specified)\n";
}
}
d_ptr->ctx_params.n_ctx = n_ctx;
d_ptr->ctx_params.seed = params.seed;
d_ptr->ctx_params.type_k = params.kv_type;
@ -314,6 +345,9 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
d_ptr->ctx_params.n_threads = d_ptr->n_threads;
d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;
if (m_supportsEmbedding)
d_ptr->ctx_params.embeddings = true;
d_ptr->ctx = llama_new_context_with_model(d_ptr->model, d_ptr->ctx_params);
if (!d_ptr->ctx) {
fflush(stdout);
@ -332,6 +366,9 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
}
#endif
m_supportsEmbedding = isEmbedding;
m_supportsCompletion = !isEmbedding;
fflush(stdout);
d_ptr->modelLoaded = true;
return true;
@ -535,6 +572,320 @@ bool LLamaModel::usingGPUDevice()
#endif
}
void llama_batch_add(
struct llama_batch & batch,
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits) {
batch.token [batch.n_tokens] = id;
batch.pos [batch.n_tokens] = pos;
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
for (size_t i = 0; i < seq_ids.size(); ++i) {
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
}
batch.logits [batch.n_tokens] = logits;
batch.n_tokens++;
}
static void batch_add_seq(llama_batch &batch, const std::vector<LLModel::Token> &tokens, int seq_id) {
for (unsigned i = 0; i < tokens.size(); i++) {
llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
}
}
size_t LLamaModel::embeddingSize() const {
return llama_n_embd(d_ptr->model);
}
struct EmbModelSpec {
const char *docPrefix;
const char *queryPrefix;
std::vector<const char *> otherPrefixes = {};
bool matryoshkaCapable = false;
const char *recommendedDims = nullptr;
};
struct EmbModelGroup {
EmbModelSpec spec;
std::vector<const char *> names;
};
static const EmbModelSpec NOPREFIX_SPEC {nullptr, nullptr};
static const EmbModelSpec NOMIC_SPEC {"search_document", "search_query", {"clustering", "classification"}};
static const EmbModelSpec E5_SPEC {"passage", "query"};
static const EmbModelSpec NOMIC_1_5_SPEC {
"search_document", "search_query", {"clustering", "classification"}, true, "[768, 512, 384, 256, 128]"
};
static const EmbModelSpec LLM_EMBEDDER_SPEC {
"Represent this document for retrieval",
"Represent this query for retrieving relevant documents",
};
static const EmbModelSpec BGE_SPEC {
nullptr, "Represent this sentence for searching relevant passages",
};
static const EmbModelSpec E5_MISTRAL_SPEC {
nullptr, "Instruct: Given a query, retrieve relevant passages that answer the query\nQuery",
};
static const EmbModelGroup EMBEDDING_MODEL_SPECS[] {
{NOPREFIX_SPEC, {"all-MiniLM-L6-v1", "all-MiniLM-L12-v1", "all-MiniLM-L6-v2", "all-MiniLM-L12-v2"}},
{NOMIC_SPEC, {"nomic-embed-text-v1", "nomic-embed-text-v1-ablated", "nomic-embed-text-v1-unsupervised"}},
{NOMIC_1_5_SPEC, {"nomic-embed-text-v1.5"}},
{LLM_EMBEDDER_SPEC, {"llm-embedder"}},
{BGE_SPEC, {"bge-small-en", "bge-base-en", "bge-large-en",
"bge-small-en-v1.5", "bge-base-en-v1.5", "bge-large-en-v1.5"}},
{E5_SPEC, {"e5-small", "e5-base", "e5-large",
"e5-small-unsupervised", "e5-base-unsupervised", "e5-large-unsupervised",
"e5-small-v2", "e5-base-v2", "e5-large-v2"}},
{E5_MISTRAL_SPEC, {"e5-mistral-7b-instruct",
"multilingual-e5-small", "multilingual-e5-base", "multilingual-e5-large",
"multilingual-e5-large-instruct"}},
};
static const EmbModelSpec *getEmbedSpec(const std::string &modelName) {
static const auto &specs = EMBEDDING_MODEL_SPECS;
auto it = std::find_if(specs, std::end(specs),
[&modelName](auto &spec) {
auto &names = spec.names;
return std::find(names.begin(), names.end(), modelName) < names.end();
}
);
return it < std::end(specs) ? &it->spec : nullptr;
}
void LLamaModel::embed(
const std::vector<std::string> &texts, float *embeddings, bool isRetrieval, int dimensionality, bool doMean,
bool atlas
) {
const EmbModelSpec *spec;
std::optional<std::string> prefix;
if (d_ptr->model && (spec = getEmbedSpec(llama_model_name(d_ptr->model))))
prefix = isRetrieval ? spec->queryPrefix : spec->docPrefix;
embed(texts, embeddings, prefix, dimensionality, doMean, atlas);
}
void LLamaModel::embed(
const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix, int dimensionality,
bool doMean, bool atlas
) {
if (!d_ptr->model)
throw std::logic_error("no model is loaded");
const char *modelName = llama_model_name(d_ptr->model);
if (!m_supportsEmbedding)
throw std::logic_error("not an embedding model: "s + modelName);
auto *spec = getEmbedSpec(modelName);
if (!spec)
std::cerr << __func__ << ": warning: unknown model " << modelName << "\n";
const int32_t n_embd = llama_n_embd(d_ptr->model);
if (dimensionality < 0) {
dimensionality = n_embd;
} else if (spec && dimensionality != n_embd) {
auto msg = [dimensionality, modelName]() {
return "unsupported dimensionality " + std::to_string(dimensionality) + " for model " + modelName;
};
if (!spec->matryoshkaCapable)
throw std::logic_error(msg() + " (supported: " + std::to_string(n_embd) + ")");
if (dimensionality == 0 || dimensionality > n_embd)
throw std::logic_error(msg() + " (recommended: " + spec->recommendedDims + ")");
}
if (!prefix) {
if (spec) {
prefix = spec->docPrefix;
} else {
std::cerr << __func__ << ": warning: assuming no prefix\n";
prefix = "";
}
} else if (spec && prefix != spec->docPrefix && prefix != spec->queryPrefix &&
std::find(spec->otherPrefixes.begin(), spec->otherPrefixes.end(), *prefix) == spec->otherPrefixes.end())
{
std::stringstream ss;
ss << std::quoted(*prefix) << " is not a valid task type for model " << modelName;
throw std::logic_error(ss.str());
}
embedInternal(texts, embeddings, *prefix, dimensionality, doMean, atlas, spec);
}
// MD5 hash of "nomic empty"
static const char EMPTY_PLACEHOLDER[] = "24df574ea1c998de59d5be15e769658e";
auto product(double a) -> std::function<double(double)> {
return [a](double b) { return a * b; };
}
template <typename T>
double getL2NormScale(T *start, T *end) {
double magnitude = std::sqrt(std::inner_product(start, end, start, 0.0));
return 1.0 / std::max(magnitude, 1e-12);
}
void LLamaModel::embedInternal(
const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
bool doMean, bool atlas, const EmbModelSpec *spec
) {
typedef std::vector<LLModel::Token> TokenString;
static constexpr int32_t atlasMaxLength = 8192;
static constexpr int chunkOverlap = 8; // Atlas overlaps n_batch-sized chunks of input by 8 tokens
const llama_token bos_token = llama_token_bos(d_ptr->model);
const llama_token eos_token = llama_token_eos(d_ptr->model);
assert(shouldAddBOS());
bool addEOS = llama_vocab_type(d_ptr->model) == LLAMA_VOCAB_TYPE_WPM;
// no EOS, optional BOS
auto tokenize = [this, addEOS](std::string text, TokenString &tokens, bool addBOS) {
if (!text.empty() && text[0] != ' ')
text = ' ' + text; // normalize for SPM - our fork of llama.cpp doesn't add a space prefix
tokens.resize(text.length()+4);
int32_t n_tokens = llama_tokenize(d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), addBOS, false);
assert(addEOS == (eos_token != -1 && tokens[n_tokens - 1] == eos_token));
tokens.resize(n_tokens - addEOS); // erase EOS/SEP
};
// tokenize the texts
std::vector<TokenString> inputs;
for (unsigned i = 0; i < texts.size(); i++) {
auto &text = texts[i];
auto &inp = inputs.emplace_back();
tokenize(text, inp, false);
if (atlas && inp.size() > atlasMaxLength) {
if (doMean) {
throw std::logic_error(
"length of text at index " + std::to_string(i) + " is " + std::to_string(inp.size()) +
" tokens which exceeds limit of " + std::to_string(atlasMaxLength)
);
}
inp.resize(atlasMaxLength);
} else if (inp.empty()) {
if (!atlas || !text.empty()) {
std::cerr << __func__ << ": warning: chunking tokenized text at index " << std::to_string(i)
<< " into zero tokens\n";
}
tokenize(EMPTY_PLACEHOLDER, inp, false);
}
}
// tokenize the prefix
TokenString prefixTokens;
if (prefix.empty()) {
prefixTokens.push_back(bos_token);
} else {
tokenize(prefix + ':', prefixTokens, true);
}
const uint32_t n_batch = llama_n_batch(d_ptr->ctx);
const uint32_t max_len = n_batch - (prefixTokens.size() + addEOS); // minus BOS/CLS and EOS/SEP
if (chunkOverlap >= max_len) {
throw std::logic_error("max chunk length of " + std::to_string(max_len) + " is smaller than overlap of " +
std::to_string(chunkOverlap) + " tokens");
}
// split into max_len-sized chunks
struct split_batch { int idx; TokenString batch; };
std::vector<split_batch> batches;
for (unsigned i = 0; i < inputs.size(); i++) {
auto &input = inputs[i];
for (auto it = input.begin(); it < input.end(); it += max_len) {
if (it > input.begin()) { it -= chunkOverlap; }
auto end = std::min(it + max_len, input.end());
auto &batch = batches.emplace_back(i, prefixTokens).batch;
batch.insert(batch.end(), it, end);
batch.push_back(eos_token);
if (!doMean) { break; /* limit text to one chunk */ }
}
}
inputs.clear();
// initialize batch
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
// n_texts x n_embd matrix
const int32_t n_embd = llama_n_embd(d_ptr->model);
std::vector<double> embeddingsSum(texts.size() * n_embd);
std::vector<int> embeddingsSumTotal(texts.size());
std::vector<int> queued_indices; // text indices of batches to be processed
auto decode = [this, &queued_indices, n_embd, &batch, &embeddingsSum, &embeddingsSumTotal, spec, dimensionality]() {
if (llama_decode(d_ptr->ctx, batch) < 0)
throw std::runtime_error("llama_decode failed");
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i]) { continue; }
int i_prompt = queued_indices[batch.seq_id[i][0]];
auto *out = &embeddingsSum[i_prompt * n_embd];
// sequence embeddings aren't available when pooling_type is NONE
auto *embd = llama_get_embeddings_seq(d_ptr->ctx, batch.seq_id[i][0]);
if (!embd) { embd = llama_get_embeddings_ith(d_ptr->ctx, i); }
assert(embd);
auto *embd_end = embd + n_embd;
// layer normalization for nomic-embed-text-v1.5
if (spec && spec->matryoshkaCapable) {
// normalize mean
double mean = std::accumulate(embd, embd_end, 0.0) / n_embd;
std::transform(embd, embd_end, embd, [mean](double f){ return f - mean; });
// unbiased sample variance, with Bessel's correction
double variance = std::inner_product(embd, embd_end, embd, 0.0) / (n_embd - 1);
// trim to matryoshka dim
embd_end = embd + dimensionality;
// normalize variance
std::transform(embd, embd_end, embd, product(1.0 / std::sqrt(variance + 1e-5)));
}
// L2 norm
auto scale = getL2NormScale(embd, embd_end);
std::transform(embd, embd_end, out, out, [scale](double e, double o){ return o + scale * e; });
embeddingsSumTotal[i_prompt]++;
}
};
// break into batches
for (auto &inp: batches) {
// encode if at capacity
if (batch.n_tokens + inp.batch.size() > n_batch) {
decode();
batch.n_tokens = 0;
queued_indices.clear();
}
// add to batch
batch_add_seq(batch, inp.batch, queued_indices.size());
queued_indices.push_back(inp.idx);
}
// final batch
decode();
for (unsigned i = 0; i < texts.size(); i++) {
auto *embd = &embeddingsSum[i * n_embd];
auto *embd_end = embd + dimensionality;
int total = embeddingsSumTotal[i];
// average over chunks
std::transform(embd, embd_end, embd, product(1.0 / total));
// L2 norm and copy
auto scale = getL2NormScale(embd, embd_end);
std::transform(embd, embd_end, embeddings, product(scale));
embeddings += dimensionality;
}
}
#if defined(_WIN32)
#define DLL_EXPORT __declspec(dllexport)
#else
@ -556,23 +907,21 @@ DLL_EXPORT const char *get_build_variant() {
DLL_EXPORT bool magic_match(const char *fname) {
auto * ctx = load_gguf(fname);
auto arch = get_arch_name(ctx);
std::string arch = get_arch_name(ctx);
bool valid = true;
static const std::vector<const char *> known_arches {
"baichuan", "bloom", "codeshell", "falcon", "gemma", "gpt2", "llama", "mpt", "orion", "persimmon", "phi2",
"plamo", "qwen", "qwen2", "refact", "stablelm", "starcoder"
};
if (std::find(known_arches.begin(), known_arches.end(), arch) == known_arches.end()) {
if (std::find(KNOWN_ARCHES.begin(), KNOWN_ARCHES.end(), arch) == KNOWN_ARCHES.end()) {
// not supported by this version of llama.cpp
if (!(arch == "gptj" || arch == "bert")) { // we support these via other modules
if (arch != "gptj") { // we support this via another module
std::cerr << __func__ << ": unsupported model architecture: " << arch << "\n";
}
valid = false;
}
if (valid && is_embedding_arch(arch) && gguf_find_key(ctx, (arch + ".pooling_type").c_str()) < 0)
valid = false; // old pre-llama.cpp embedding model, e.g. all-MiniLM-L6-v2-f16.gguf
gguf_free(ctx);
return valid;
}

@ -11,15 +11,18 @@
#include "llmodel.h"
struct LLamaPrivate;
struct EmbModelSpec;
class LLamaModel : public LLModel {
public:
LLamaModel();
~LLamaModel();
bool supportsEmbedding() const override { return false; }
bool supportsCompletion() const override { return true; }
bool supportsEmbedding() const override { return m_supportsEmbedding; }
bool supportsCompletion() const override { return m_supportsCompletion; }
bool loadModel(const std::string &modelPath, int n_ctx, int ngl) override;
bool isModelBlacklisted(const std::string &modelPath) override;
bool isModelBlacklisted(const std::string &modelPath) const override;
bool isEmbeddingModel(const std::string &modelPath) const override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
size_t stateSize() const override;
@ -29,12 +32,22 @@ public:
int32_t threadCount() const override;
std::vector<GPUDevice> availableGPUDevices(size_t memoryRequired) const override;
bool initializeGPUDevice(size_t memoryRequired, const std::string &name) const override;
bool initializeGPUDevice(int device, std::string *unavail_reason) const override;
bool initializeGPUDevice(int device, std::string *unavail_reason = nullptr) const override;
bool hasGPUDevice() override;
bool usingGPUDevice() override;
size_t embeddingSize() const override;
// user-specified prefix
void embed(const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix,
int dimensionality = -1, bool doMean = true, bool atlas = false) override;
// automatic prefix
void embed(const std::vector<std::string> &texts, float *embeddings, bool isRetrieval, int dimensionality = -1,
bool doMean = true, bool atlas = false) override;
private:
std::unique_ptr<LLamaPrivate> d_ptr;
bool m_supportsEmbedding = false;
bool m_supportsCompletion = false;
protected:
std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override;
@ -44,9 +57,11 @@ protected:
int32_t contextLength() const override;
const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override;
int32_t maxContextLength(std::string const &modelPath) const override;
int32_t layerCount(std::string const &modelPath) const override;
void embedInternal(const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
bool doMean, bool atlas, const EmbModelSpec *spec);
};
#endif // LLAMAMODEL_H

@ -213,21 +213,26 @@ LLModel *LLModel::Implementation::constructDefaultLlama() {
}
std::vector<LLModel::GPUDevice> LLModel::Implementation::availableGPUDevices() {
auto * llama = constructDefaultLlama();
auto *llama = constructDefaultLlama();
if (llama) { return llama->availableGPUDevices(0); }
return {};
}
int32_t LLModel::Implementation::maxContextLength(const std::string &modelPath) {
auto * llama = constructDefaultLlama();
auto *llama = constructDefaultLlama();
return llama ? llama->maxContextLength(modelPath) : -1;
}
int32_t LLModel::Implementation::layerCount(const std::string &modelPath) {
auto * llama = constructDefaultLlama();
auto *llama = constructDefaultLlama();
return llama ? llama->layerCount(modelPath) : -1;
}
bool LLModel::Implementation::isEmbeddingModel(const std::string &modelPath) {
auto *llama = constructDefaultLlama();
return llama && llama->isEmbeddingModel(modelPath);
}
void LLModel::Implementation::setImplementationsSearchPath(const std::string& path) {
s_implementations_search_path = path;
}

@ -1,13 +1,14 @@
#ifndef LLMODEL_H
#define LLMODEL_H
#include <string>
#include <functional>
#include <vector>
#include <string_view>
#include <fstream>
#include <cstdint>
#include <fstream>
#include <functional>
#include <limits>
#include <optional>
#include <string>
#include <string_view>
#include <vector>
#define LLMODEL_MAX_PROMPT_BATCH 128
@ -44,6 +45,7 @@ public:
static std::vector<GPUDevice> availableGPUDevices();
static int32_t maxContextLength(const std::string &modelPath);
static int32_t layerCount(const std::string &modelPath);
static bool isEmbeddingModel(const std::string &modelPath);
static void setImplementationsSearchPath(const std::string &path);
static const std::string &implementationsSearchPath();
@ -83,7 +85,8 @@ public:
virtual bool supportsEmbedding() const = 0;
virtual bool supportsCompletion() const = 0;
virtual bool loadModel(const std::string &modelPath, int n_ctx, int ngl) = 0;
virtual bool isModelBlacklisted(const std::string &modelPath) { (void)modelPath; return false; };
virtual bool isModelBlacklisted(const std::string &modelPath) const { (void)modelPath; return false; };
virtual bool isEmbeddingModel(const std::string &modelPath) const { (void)modelPath; return false; }
virtual bool isModelLoaded() const = 0;
virtual size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) = 0;
virtual size_t stateSize() const { return 0; }
@ -101,7 +104,15 @@ public:
bool special = false,
std::string *fakeReply = nullptr);
virtual std::vector<float> embedding(const std::string &text);
virtual size_t embeddingSize() const {
throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
}
// user-specified prefix
virtual void embed(const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix,
int dimensionality = -1, bool doMean = true, bool atlas = false);
// automatic prefix
virtual void embed(const std::vector<std::string> &texts, float *embeddings, bool isRetrieval,
int dimensionality = -1, bool doMean = true, bool atlas = false);
virtual void setThreadCount(int32_t n_threads) { (void)n_threads; }
virtual int32_t threadCount() const { return 1; }

@ -4,6 +4,7 @@
#include <cerrno>
#include <cstring>
#include <iostream>
#include <optional>
#include <utility>
struct LLModelWrapper {
@ -41,22 +42,22 @@ llmodel_model llmodel_model_create2(const char *model_path, const char *build_va
*error = last_error_message.c_str();
}
}
return reinterpret_cast<llmodel_model*>(wrapper);
return wrapper;
}
void llmodel_model_destroy(llmodel_model model) {
delete reinterpret_cast<LLModelWrapper*>(model);
delete static_cast<LLModelWrapper *>(model);
}
size_t llmodel_required_mem(llmodel_model model, const char *model_path, int n_ctx, int ngl)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->requiredMem(model_path, n_ctx, ngl);
}
bool llmodel_loadModel(llmodel_model model, const char *model_path, int n_ctx, int ngl)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
auto *wrapper = static_cast<LLModelWrapper *>(model);
std::string modelPath(model_path);
if (wrapper->llModel->isModelBlacklisted(modelPath)) {
@ -69,44 +70,28 @@ bool llmodel_loadModel(llmodel_model model, const char *model_path, int n_ctx, i
bool llmodel_isModelLoaded(llmodel_model model)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->isModelLoaded();
}
uint64_t llmodel_get_state_size(llmodel_model model)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->stateSize();
}
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->saveState(dest);
}
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->restoreState(src);
}
// Wrapper functions for the C callbacks
bool prompt_wrapper(int32_t token_id, void *user_data) {
llmodel_prompt_callback callback = reinterpret_cast<llmodel_prompt_callback>(user_data);
return callback(token_id);
}
bool response_wrapper(int32_t token_id, const std::string &response, void *user_data) {
llmodel_response_callback callback = reinterpret_cast<llmodel_response_callback>(user_data);
return callback(token_id, response.c_str());
}
bool recalculate_wrapper(bool is_recalculating, void *user_data) {
llmodel_recalculate_callback callback = reinterpret_cast<llmodel_recalculate_callback>(user_data);
return callback(is_recalculating);
}
void llmodel_prompt(llmodel_model model, const char *prompt,
const char *prompt_template,
llmodel_prompt_callback prompt_callback,
@ -116,15 +101,11 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
bool special,
const char *fake_reply)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
auto *wrapper = static_cast<LLModelWrapper *>(model);
// Create std::function wrappers that call the C function pointers
std::function<bool(int32_t)> prompt_func =
std::bind(&prompt_wrapper, std::placeholders::_1, reinterpret_cast<void*>(prompt_callback));
std::function<bool(int32_t, const std::string&)> response_func =
std::bind(&response_wrapper, std::placeholders::_1, std::placeholders::_2, reinterpret_cast<void*>(response_callback));
std::function<bool(bool)> recalc_func =
std::bind(&recalculate_wrapper, std::placeholders::_1, reinterpret_cast<void*>(recalculate_callback));
auto response_func = [response_callback](int32_t token_id, const std::string &response) {
return response_callback(token_id, response.c_str());
};
if (size_t(ctx->n_past) < wrapper->promptContext.tokens.size())
wrapper->promptContext.tokens.resize(ctx->n_past);
@ -147,8 +128,8 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
auto *fake_reply_p = fake_reply ? &fake_reply_str : nullptr;
// Call the C++ prompt method
wrapper->llModel->prompt(prompt, prompt_template, prompt_func, response_func, recalc_func, wrapper->promptContext,
special, fake_reply_p);
wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, recalculate_callback,
wrapper->promptContext, special, fake_reply_p);
// Update the C context by giving access to the wrappers raw pointers to std::vector data
// which involves no copies
@ -171,38 +152,60 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
ctx->context_erase = wrapper->promptContext.contextErase;
}
float *llmodel_embedding(llmodel_model model, const char *text, size_t *embedding_size)
{
if (model == nullptr || text == nullptr || !strlen(text)) {
*embedding_size = 0;
float *llmodel_embed(
llmodel_model model, const char **texts, size_t *embedding_size, const char *prefix, int dimensionality,
bool do_mean, bool atlas, const char **error
) {
auto *wrapper = static_cast<LLModelWrapper *>(model);
if (!texts || !*texts) {
if (error)
*error = strdup("'texts' is NULL or empty");
return nullptr;
}
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
std::vector<float> embeddingVector = wrapper->llModel->embedding(text);
float *embedding = (float *)malloc(embeddingVector.size() * sizeof(float));
if (embedding == nullptr) {
*embedding_size = 0;
std::vector<std::string> textsVec;
while (*texts) { textsVec.emplace_back(*texts++); }
size_t embd_size;
float *embedding;
try {
embd_size = wrapper->llModel->embeddingSize();
if (dimensionality > 0 && dimensionality < int(embd_size))
embd_size = dimensionality;
embd_size *= textsVec.size();
std::optional<std::string> prefixStr;
if (prefix) { prefixStr = prefix; }
embedding = new float[embd_size];
wrapper->llModel->embed(textsVec, embedding, prefixStr, dimensionality, do_mean, atlas);
} catch (std::exception const &e) {
if (error)
*error = strdup(e.what());
return nullptr;
}
std::copy(embeddingVector.begin(), embeddingVector.end(), embedding);
*embedding_size = embeddingVector.size();
*embedding_size = embd_size;
return embedding;
}
void llmodel_free_embedding(float *ptr)
{
free(ptr);
delete[] ptr;
}
void llmodel_setThreadCount(llmodel_model model, int32_t n_threads)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
auto *wrapper = static_cast<LLModelWrapper *>(model);
wrapper->llModel->setThreadCount(n_threads);
}
int32_t llmodel_threadCount(llmodel_model model)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->threadCount();
}
@ -218,7 +221,7 @@ const char *llmodel_get_implementation_search_path()
struct llmodel_gpu_device* llmodel_available_gpu_devices(llmodel_model model, size_t memoryRequired, int* num_devices)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
auto *wrapper = static_cast<LLModelWrapper *>(model);
std::vector<LLModel::GPUDevice> devices = wrapper->llModel->availableGPUDevices(memoryRequired);
// Set the num_devices
@ -242,24 +245,24 @@ struct llmodel_gpu_device* llmodel_available_gpu_devices(llmodel_model model, si
bool llmodel_gpu_init_gpu_device_by_string(llmodel_model model, size_t memoryRequired, const char *device)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->initializeGPUDevice(memoryRequired, std::string(device));
}
bool llmodel_gpu_init_gpu_device_by_struct(llmodel_model model, const llmodel_gpu_device *device)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->initializeGPUDevice(device->index);
}
bool llmodel_gpu_init_gpu_device_by_int(llmodel_model model, int device)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->initializeGPUDevice(device);
}
bool llmodel_has_gpu_device(llmodel_model model)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->hasGPUDevice();
}

@ -186,13 +186,23 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
* NOTE: If given NULL pointers for the model or text, or an empty text, a NULL pointer will be
* returned. Bindings should signal an error when NULL is the return value.
* @param model A pointer to the llmodel_model instance.
* @param text A string representing the text to generate an embedding for.
* @param texts A pointer to a NULL-terminated array of strings representing the texts to generate an
* embedding for.
* @param embedding_size A pointer to a size_t type that will be set by the call indicating the length
* of the returned floating point array.
* @param prefix The model-specific prefix representing the embedding task, without the trailing colon. NULL for no
* prefix.
* @param dimensionality The embedding dimension, for use with Matryoshka-capable models. Set to -1 to for full-size.
* @param do_mean True to average multiple embeddings if the text is longer than the model can accept, False to
* truncate.
* @param atlas Try to be fully compatible with the Atlas API. Currently, this means texts longer than 8192 tokens with
* long_text_mode="mean" will raise an error. Disabled by default.
* @param error Return location for a malloc()ed string that will be set on error, or NULL.
* @return A pointer to an array of floating point values passed to the calling method which then will
* be responsible for lifetime of this memory.
* be responsible for lifetime of this memory. NULL if an error occurred.
*/
float *llmodel_embedding(llmodel_model model, const char *text, size_t *embedding_size);
float *llmodel_embed(llmodel_model model, const char **texts, size_t *embedding_size, const char *prefix,
int dimensionality, bool do_mean, bool atlas, const char **error);
/**
* Frees the memory allocated by the llmodel_embedding function.

@ -3,6 +3,7 @@
#include <cassert>
#include <iostream>
#include <regex>
#include <string>
#include <unordered_set>
// TODO(cebtenzzre): replace this with llama_kv_cache_seq_shift for llamamodel (GPT-J needs this as-is)
@ -267,12 +268,28 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
}
}
std::vector<float> LLModel::embedding(const std::string &text)
{
(void)text;
if (!supportsCompletion()) {
std::string errorMessage = "ERROR: this model does not support generating embeddings!\n";
std::cerr << implementation().modelType() << errorMessage;
}
return std::vector<float>();
void LLModel::embed(
const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix, int dimensionality,
bool doMean, bool atlas
) {
(void)texts;
(void)embeddings;
(void)prefix;
(void)dimensionality;
(void)doMean;
(void)atlas;
throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
}
void LLModel::embed(
const std::vector<std::string> &texts, float *embeddings, bool isRetrieval, int dimensionality, bool doMean,
bool atlas
) {
(void)texts;
(void)embeddings;
(void)isRetrieval;
(void)dimensionality;
(void)doMean;
(void)atlas;
throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
}

@ -10,7 +10,7 @@ import sys
import threading
from enum import Enum
from queue import Queue
from typing import Callable, Iterable, List
from typing import Callable, Iterable, overload
if sys.version_info >= (3, 9):
import importlib.resources as importlib_resources
@ -105,13 +105,18 @@ llmodel.llmodel_prompt.argtypes = [
llmodel.llmodel_prompt.restype = None
llmodel.llmodel_embedding.argtypes = [
llmodel.llmodel_embed.argtypes = [
ctypes.c_void_p,
ctypes.c_char_p,
ctypes.POINTER(ctypes.c_char_p),
ctypes.POINTER(ctypes.c_size_t),
ctypes.c_char_p,
ctypes.c_int,
ctypes.c_bool,
ctypes.c_bool,
ctypes.POINTER(ctypes.c_char_p),
]
llmodel.llmodel_embedding.restype = ctypes.POINTER(ctypes.c_float)
llmodel.llmodel_embed.restype = ctypes.POINTER(ctypes.c_float)
llmodel.llmodel_free_embedding.argtypes = [ctypes.POINTER(ctypes.c_float)]
llmodel.llmodel_free_embedding.restype = None
@ -287,16 +292,50 @@ class LLModel:
self.context.repeat_last_n = repeat_last_n
self.context.context_erase = context_erase
def generate_embedding(self, text: str) -> List[float]:
@overload
def generate_embeddings(
self, text: str, prefix: str, dimensionality: int, do_mean: bool, atlas: bool,
) -> list[float]: ...
@overload
def generate_embeddings(
self, text: list[str], prefix: str, dimensionality: int, do_mean: bool, atlas: bool,
) -> list[list[float]]: ...
def generate_embeddings(self, text, prefix, dimensionality, do_mean, atlas):
if not text:
raise ValueError("Text must not be None or empty")
raise ValueError("text must not be None or empty")
single_text = isinstance(text, str)
if single_text:
text = [text]
# prepare input
embedding_size = ctypes.c_size_t()
c_text = ctypes.c_char_p(text.encode())
embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))
embedding_array = [embedding_ptr[i] for i in range(embedding_size.value)]
error = ctypes.c_char_p()
c_prefix = ctypes.c_char_p() if prefix is None else prefix.encode()
c_texts = (ctypes.c_char_p * (len(text) + 1))()
for i, t in enumerate(text):
c_texts[i] = t.encode()
# generate the embeddings
embedding_ptr = llmodel.llmodel_embed(
self.model, c_texts, ctypes.byref(embedding_size), c_prefix, dimensionality, do_mean, atlas,
ctypes.byref(error),
)
if embedding_ptr.value is None:
msg = "(unknown error)" if error.value is None else error.value.decode()
raise RuntimeError(f'Failed to generate embeddings: {msg}')
# extract output
n_embd = embedding_size.value // len(text)
embedding_array = [
embedding_ptr[i:i + n_embd]
for i in range(0, embedding_size.value, n_embd)
]
llmodel.llmodel_free_embedding(embedding_ptr)
return list(embedding_array)
return embedding_array[0] if single_text else embedding_array
def prompt_model(
self,

@ -10,7 +10,7 @@ import time
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Union, overload
import requests
from requests.exceptions import ChunkedEncodingError
@ -36,6 +36,8 @@ class Embed4All:
Python class that handles embeddings for GPT4All.
"""
MIN_DIMENSIONALITY = 64
def __init__(self, model_name: Optional[str] = None, n_threads: Optional[int] = None, **kwargs):
"""
Constructor
@ -45,17 +47,48 @@ class Embed4All:
"""
self.gpt4all = GPT4All(model_name or 'all-MiniLM-L6-v2-f16.gguf', n_threads=n_threads, **kwargs)
def embed(self, text: str) -> List[float]:
@overload
def embed(
self, text: str, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
atlas: bool = ...,
) -> list[float]: ...
@overload
def embed(
self, text: list[str], prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
atlas: bool = ...,
) -> list[list[float]]: ...
def embed(self, text, prefix=None, dimensionality=None, long_text_mode="truncate", atlas=False):
"""
Generate an embedding.
Generate one or more embeddings.
Args:
text: The text document to generate an embedding for.
text: A text or list of texts to generate embeddings for.
prefix: The model-specific prefix representing the embedding task, without the trailing colon. For Nomic
Embed this can be `search_query`, `search_document`, `classification`, or `clustering`.
dimensionality: The embedding dimension, for use with Matryoshka-capable models. Defaults to full-size.
long_text_mode: How to handle texts longer than the model can accept. One of `mean` or `truncate`.
atlas: Try to be fully compatible with the Atlas API. Currently, this means texts longer than 8192 tokens
with long_text_mode="mean" will raise an error. Disabled by default.
Returns:
An embedding of your document of text.
An embedding or list of embeddings of your text(s).
"""
return self.gpt4all.model.generate_embedding(text)
if dimensionality is None:
dimensionality = -1
else:
if dimensionality <= 0:
raise ValueError(f'Dimensionality must be None or a positive integer, got {dimensionality}')
if dimensionality < self.MIN_DIMENSIONALITY:
warnings.warn(
f'Dimensionality {dimensionality} is less than the suggested minimum of {self.MIN_DIMENSIONALITY}.'
' Performance may be degraded.'
)
try:
do_mean = {"mean": True, "truncate": False}[long_text_mode]
except KeyError:
raise ValueError(f"Long text mode must be one of 'mean' or 'truncate', got {long_text_mode!r}")
return self.gpt4all.model.generate_embeddings(text, prefix, dimensionality, do_mean, atlas)
class GPT4All:

@ -202,8 +202,6 @@ install(TARGETS llamamodel-mainline-default DESTINATION lib COMPONENT ${COMPONEN
if(APPLE)
install(TARGETS llamamodel-mainline-metal DESTINATION lib COMPONENT ${COMPONENT_NAME_MAIN})
endif()
install(TARGETS bert-avxonly DESTINATION lib COMPONENT ${COMPONENT_NAME_MAIN})
install(TARGETS bert-default DESTINATION lib COMPONENT ${COMPONENT_NAME_MAIN})
set(CPACK_GENERATOR "IFW")
set(CPACK_VERBATIM_VARIABLES YES)

@ -12,7 +12,6 @@
#define GPTJ_INTERNAL_STATE_VERSION 0
#define LLAMA_INTERNAL_STATE_VERSION 0
#define BERT_INTERNAL_STATE_VERSION 0
class LLModelStore {
public:
@ -386,7 +385,6 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
switch (m_llModelInfo.model->implementation().modelType()[0]) {
case 'L': m_llModelType = LLModelType::LLAMA_; break;
case 'G': m_llModelType = LLModelType::GPTJ_; break;
case 'B': m_llModelType = LLModelType::BERT_; break;
default:
{
delete m_llModelInfo.model;
@ -840,7 +838,6 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
switch (m_llModelType) {
case GPTJ_: stream << GPTJ_INTERNAL_STATE_VERSION; break;
case LLAMA_: stream << LLAMA_INTERNAL_STATE_VERSION; break;
case BERT_: stream << BERT_INTERNAL_STATE_VERSION; break;
default: Q_UNREACHABLE();
}
}

@ -13,7 +13,6 @@ enum LLModelType {
GPTJ_,
LLAMA_,
CHATGPT_,
BERT_,
};
struct LLModelInfo {

@ -27,7 +27,7 @@ void EmbeddingLLMWorker::wait()
bool EmbeddingLLMWorker::loadModel()
{
const EmbeddingModels *embeddingModels = ModelList::globalInstance()->embeddingModels();
const EmbeddingModels *embeddingModels = ModelList::globalInstance()->installedEmbeddingModels();
if (!embeddingModels->count())
return false;
@ -41,7 +41,8 @@ bool EmbeddingLLMWorker::loadModel()
return false;
}
bool isNomic = fileInfo.fileName().startsWith("nomic");
auto filename = fileInfo.fileName();
bool isNomic = filename.startsWith("nomic-") && filename.endsWith(".txt");
if (isNomic) {
QFile file(filePath);
file.open(QIODeviceBase::ReadOnly | QIODeviceBase::Text);
@ -52,16 +53,18 @@ bool EmbeddingLLMWorker::loadModel()
}
m_model = LLModel::Implementation::construct(filePath.toStdString());
// NOTE: explicitly loads model on CPU to avoid GPU OOM
// TODO(cebtenzzre): support GPU-accelerated embeddings
bool success = m_model->loadModel(filePath.toStdString(), 2048, 0);
if (!success) {
qWarning() << "WARNING: Could not load sbert";
qWarning() << "WARNING: Could not load embedding model";
delete m_model;
m_model = nullptr;
return false;
}
if (m_model->implementation().modelType() != "Bert") {
qWarning() << "WARNING: Model type is not sbert";
if (!m_model->supportsEmbedding()) {
qWarning() << "WARNING: Model type does not support embeddings";
delete m_model;
m_model = nullptr;
return false;
@ -79,21 +82,49 @@ bool EmbeddingLLMWorker::isNomic() const
return !m_nomicAPIKey.isEmpty();
}
// this function is always called for retrieval tasks
std::vector<float> EmbeddingLLMWorker::generateSyncEmbedding(const QString &text)
{
if (!hasModel() && !loadModel()) {
qWarning() << "WARNING: Could not load model for embeddings";
return std::vector<float>();
return {};
}
if (isNomic()) {
qWarning() << "WARNING: Request to generate sync embeddings for non-local model invalid";
return std::vector<float>();
return {};
}
return m_model->embedding(text.toStdString());
std::vector<float> embedding(m_model->embeddingSize());
try {
m_model->embed({text.toStdString()}, embedding.data(), true);
} catch (const std::exception &e) {
qWarning() << "WARNING: LLModel::embed failed: " << e.what();
return {};
}
return embedding;
}
void EmbeddingLLMWorker::sendAtlasRequest(const QStringList &texts, const QString &taskType, QVariant userData) {
QJsonObject root;
root.insert("model", "nomic-embed-text-v1");
root.insert("texts", QJsonArray::fromStringList(texts));
root.insert("task_type", taskType);
QJsonDocument doc(root);
QUrl nomicUrl("https://api-atlas.nomic.ai/v1/embedding/text");
const QString authorization = QString("Bearer %1").arg(m_nomicAPIKey).trimmed();
QNetworkRequest request(nomicUrl);
request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
request.setRawHeader("Authorization", authorization.toUtf8());
request.setAttribute(QNetworkRequest::User, userData);
QNetworkReply *reply = m_networkManager->post(request, doc.toJson(QJsonDocument::Compact));
connect(qApp, &QCoreApplication::aboutToQuit, reply, &QNetworkReply::abort);
connect(reply, &QNetworkReply::finished, this, &EmbeddingLLMWorker::handleFinished);
}
// this function is always called for retrieval tasks
void EmbeddingLLMWorker::requestSyncEmbedding(const QString &text)
{
if (!hasModel() && !loadModel()) {
@ -108,25 +139,10 @@ void EmbeddingLLMWorker::requestSyncEmbedding(const QString &text)
Q_ASSERT(hasModel());
QJsonObject root;
root.insert("model", "nomic-embed-text-v1");
QJsonArray texts;
texts.append(text);
root.insert("texts", texts);
root.insert("task_type", "search_query");
QJsonDocument doc(root);
QUrl nomicUrl("https://api-atlas.nomic.ai/v1/embedding/text");
const QString authorization = QString("Bearer %1").arg(m_nomicAPIKey).trimmed();
QNetworkRequest request(nomicUrl);
request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
request.setRawHeader("Authorization", authorization.toUtf8());
QNetworkReply *reply = m_networkManager->post(request, doc.toJson(QJsonDocument::Compact));
connect(qApp, &QCoreApplication::aboutToQuit, reply, &QNetworkReply::abort);
connect(reply, &QNetworkReply::finished, this, &EmbeddingLLMWorker::handleFinished);
sendAtlasRequest({text}, "search_query");
}
// this function is always called for storage into the database
void EmbeddingLLMWorker::requestAsyncEmbedding(const QVector<EmbeddingChunk> &chunks)
{
if (!hasModel() && !loadModel()) {
@ -141,33 +157,24 @@ void EmbeddingLLMWorker::requestAsyncEmbedding(const QVector<EmbeddingChunk> &ch
EmbeddingResult result;
result.folder_id = c.folder_id;
result.chunk_id = c.chunk_id;
result.embedding = m_model->embedding(c.chunk.toStdString());
// TODO(cebtenzzre): take advantage of batched embeddings
result.embedding.resize(m_model->embeddingSize());
try {
m_model->embed({c.chunk.toStdString()}, result.embedding.data(), false);
} catch (const std::exception &e) {
qWarning() << "WARNING: LLModel::embed failed:" << e.what();
return;
}
results << result;
}
emit embeddingsGenerated(results);
return;
};
QJsonObject root;
root.insert("model", "nomic-embed-text-v1");
QJsonArray texts;
for (auto c : chunks)
QStringList texts;
for (auto &c: chunks)
texts.append(c.chunk);
root.insert("texts", texts);
QJsonDocument doc(root);
QUrl nomicUrl("https://api-atlas.nomic.ai/v1/embedding/text");
const QString authorization = QString("Bearer %1").arg(m_nomicAPIKey).trimmed();
QNetworkRequest request(nomicUrl);
request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
request.setRawHeader("Authorization", authorization.toUtf8());
request.setAttribute(QNetworkRequest::User, QVariant::fromValue(chunks));
QNetworkReply *reply = m_networkManager->post(request, doc.toJson(QJsonDocument::Compact));
connect(qApp, &QCoreApplication::aboutToQuit, reply, &QNetworkReply::abort);
connect(reply, &QNetworkReply::finished, this, &EmbeddingLLMWorker::handleFinished);
sendAtlasRequest(texts, "search_document", QVariant::fromValue(chunks));
}
std::vector<float> jsonArrayToVector(const QJsonArray &jsonArray) {

@ -1,10 +1,11 @@
#ifndef EMBLLM_H
#define EMBLLM_H
#include <QNetworkAccessManager>
#include <QNetworkReply>
#include <QObject>
#include <QStringList>
#include <QThread>
#include <QNetworkReply>
#include <QNetworkAccessManager>
#include "../gpt4all-backend/llmodel.h"
@ -51,6 +52,8 @@ private Q_SLOTS:
void handleFinished();
private:
void sendAtlasRequest(const QStringList &texts, const QString &taskType, QVariant userData = {});
QString m_nomicAPIKey;
QNetworkAccessManager *m_networkManager;
std::vector<float> m_lastResponse;

@ -247,14 +247,31 @@
"filename": "all-MiniLM-L6-v2-f16.gguf",
"filesize": "45887744",
"requires": "2.5.0",
"removedIn": "2.7.4",
"ramrequired": "1",
"parameters": "40 million",
"quant": "f16",
"type": "Bert",
"embeddingModel": true,
"systemPrompt": " ",
"description": "<strong>LocalDocs text embeddings model</strong><br><ul><li>For use with LocalDocs feature<li>Used for retrieval augmented generation (RAG)",
"url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2-f16.gguf"
},
{
"order": "o",
"md5sum": "dd90e2cb7f8e9316ac3796cece9883b5",
"name": "SBert",
"filename": "all-MiniLM-L6-v2.gguf2.f16.gguf",
"filesize": "45949216",
"requires": "2.7.4",
"ramrequired": "1",
"parameters": "40 million",
"quant": "f16",
"type": "Bert",
"embeddingModel": true,
"description": "<strong>LocalDocs text embeddings model</strong><br><ul><li>For use with LocalDocs feature<li>Used for retrieval augmented generation (RAG)",
"url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2.gguf2.f16.gguf"
},
{
"order": "p",
"md5sum": "919de4dd6f25351bcb0223790db1932d",
@ -270,5 +287,39 @@
"url": "https://huggingface.co/TheBloke/em_german_mistral_v01-GGUF/resolve/main/em_german_mistral_v01.Q4_0.gguf",
"promptTemplate": "USER: %1 ASSISTANT: ",
"systemPrompt": "Du bist ein hilfreicher Assistent. "
},
{
"order": "q",
"md5sum": "60ea031126f82db8ddbbfecc668315d2",
"disableGUI": "true",
"name": "Nomic Embed Text v1",
"filename": "nomic-embed-text-v1.f16.gguf",
"filesize": "274290560",
"requires": "2.7.4",
"ramrequired": "1",
"parameters": "137 million",
"quant": "f16",
"type": "Bert",
"embeddingModel": true,
"systemPrompt": "",
"description": "nomic-embed-text-v1",
"url": "https://gpt4all.io/models/gguf/nomic-embed-text-v1.f16.gguf"
},
{
"order": "r",
"md5sum": "a5401e7f7e46ed9fcaed5b60a281d547",
"disableGUI": "true",
"name": "Nomic Embed Text v1.5",
"filename": "nomic-embed-text-v1.5.f16.gguf",
"filesize": "274290560",
"requires": "2.7.4",
"ramrequired": "1",
"parameters": "137 million",
"quant": "f16",
"type": "Bert",
"embeddingModel": true,
"systemPrompt": "",
"description": "nomic-embed-text-v1.5",
"url": "https://gpt4all.io/models/gguf/nomic-embed-text-v1.5.f16.gguf"
}
]

@ -10,8 +10,10 @@
//#define USE_LOCAL_MODELSJSON
#define DEFAULT_EMBEDDING_MODEL "all-MiniLM-L6-v2-f16.gguf"
#define NOMIC_EMBEDDING_MODEL "nomic-embed-text-v1.txt"
const char * const KNOWN_EMBEDDING_MODELS[] {
"all-MiniLM-L6-v2.gguf2.f16.gguf",
"nomic-embed-text-v1.txt",
};
QString ModelInfo::id() const
{
@ -223,6 +225,7 @@ void ModelInfo::setContextLength(int l)
int ModelInfo::maxContextLength() const
{
if (!installed || isOnline) return -1;
if (m_maxContextLength != -1) return m_maxContextLength;
auto path = (dirpath + filename()).toStdString();
int layers = LLModel::Implementation::maxContextLength(path);
@ -306,9 +309,11 @@ bool ModelInfo::shouldSaveMetadata() const
return installed && (isClone() || isDiscovered() || description() == "" /*indicates sideloaded*/);
}
EmbeddingModels::EmbeddingModels(QObject *parent)
EmbeddingModels::EmbeddingModels(QObject *parent, bool requireInstalled)
: QSortFilterProxyModel(parent)
{
m_requireInstalled = requireInstalled;
connect(this, &EmbeddingModels::rowsInserted, this, &EmbeddingModels::countChanged);
connect(this, &EmbeddingModels::rowsRemoved, this, &EmbeddingModels::countChanged);
connect(this, &EmbeddingModels::modelReset, this, &EmbeddingModels::countChanged);
@ -319,36 +324,41 @@ bool EmbeddingModels::filterAcceptsRow(int sourceRow,
const QModelIndex &sourceParent) const
{
QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent);
bool isInstalled = sourceModel()->data(index, ModelList::InstalledRole).toBool();
bool isEmbedding = sourceModel()->data(index, ModelList::FilenameRole).toString() == DEFAULT_EMBEDDING_MODEL ||
sourceModel()->data(index, ModelList::FilenameRole).toString() == NOMIC_EMBEDDING_MODEL;
return isInstalled && isEmbedding;
bool isEmbeddingModel = sourceModel()->data(index, ModelList::IsEmbeddingModelRole).toBool();
bool installed = sourceModel()->data(index, ModelList::InstalledRole).toBool();
QString filename = sourceModel()->data(index, ModelList::FilenameRole).toString();
auto &known = KNOWN_EMBEDDING_MODELS;
if (std::find(known, std::end(known), filename.toStdString()) == std::end(known))
return false; // we are currently not prepared to support other embedding models
return isEmbeddingModel && (!m_requireInstalled || installed);
}
int EmbeddingModels::count() const
int EmbeddingModels::defaultModelIndex() const
{
return rowCount();
auto *sourceListModel = qobject_cast<const ModelList*>(sourceModel());
if (!sourceListModel) return -1;
int rows = sourceListModel->rowCount();
for (int i = 0; i < rows; ++i) {
if (filterAcceptsRow(i, sourceListModel->index(i, 0).parent()))
return i;
}
return -1;
}
ModelInfo EmbeddingModels::defaultModelInfo() const
{
if (!sourceModel())
return ModelInfo();
auto *sourceListModel = qobject_cast<const ModelList*>(sourceModel());
if (!sourceListModel) return ModelInfo();
const ModelList *sourceListModel = qobject_cast<const ModelList*>(sourceModel());
if (!sourceListModel)
return ModelInfo();
const int rows = sourceListModel->rowCount();
for (int i = 0; i < rows; ++i) {
QModelIndex sourceIndex = sourceListModel->index(i, 0);
if (filterAcceptsRow(i, sourceIndex.parent())) {
const QString id = sourceListModel->data(sourceIndex, ModelList::IdRole).toString();
return sourceListModel->modelInfo(id);
}
}
int i = defaultModelIndex();
if (i < 0) return ModelInfo();
return ModelInfo();
QModelIndex sourceIndex = sourceListModel->index(i, 0);
auto id = sourceListModel->data(sourceIndex, ModelList::IdRole).toString();
return sourceListModel->modelInfo(id);
}
InstalledModels::InstalledModels(QObject *parent)
@ -365,13 +375,9 @@ bool InstalledModels::filterAcceptsRow(int sourceRow,
{
QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent);
bool isInstalled = sourceModel()->data(index, ModelList::InstalledRole).toBool();
bool showInGUI = !sourceModel()->data(index, ModelList::DisableGUIRole).toBool();
return isInstalled && showInGUI;
}
int InstalledModels::count() const
{
return rowCount();
bool isEmbeddingModel = sourceModel()->data(index, ModelList::IsEmbeddingModelRole).toBool();
// list installed chat models
return isInstalled && !isEmbeddingModel;
}
DownloadableModels::DownloadableModels(QObject *parent)
@ -432,8 +438,9 @@ ModelList *ModelList::globalInstance()
ModelList::ModelList()
: QAbstractListModel(nullptr)
, m_embeddingModels(new EmbeddingModels(this))
, m_embeddingModels(new EmbeddingModels(this, false /* all models */))
, m_installedModels(new InstalledModels(this))
, m_installedEmbeddingModels(new EmbeddingModels(this, true /* installed models */))
, m_downloadableModels(new DownloadableModels(this))
, m_asyncModelRequestOngoing(false)
, m_discoverLimit(20)
@ -445,6 +452,7 @@ ModelList::ModelList()
{
m_embeddingModels->setSourceModel(this);
m_installedModels->setSourceModel(this);
m_installedEmbeddingModels->setSourceModel(this);
m_downloadableModels->setSourceModel(this);
connect(MySettings::globalInstance(), &MySettings::modelPathChanged, this, &ModelList::updateModelsFromDirectory);
@ -494,8 +502,8 @@ const QList<QString> ModelList::userDefaultModelList() const
bool foundUserDefault = false;
for (ModelInfo *info : m_models) {
// Only installed models that are meant for GUI are suitable as a default
if (!info->installed || info->disableGUI)
// Only installed chat models are suitable as a default
if (!info->installed || info->isEmbeddingModel)
continue;
if (info->id() == userDefaultModelName) {
@ -516,13 +524,7 @@ const QList<QString> ModelList::userDefaultModelList() const
int ModelList::defaultEmbeddingModelIndex() const
{
QMutexLocker locker(&m_mutex);
for (int i = 0; i < m_models.size(); ++i) {
const ModelInfo *info = m_models.at(i);
const bool isEmbedding = info->filename() == DEFAULT_EMBEDDING_MODEL;
if (isEmbedding) return i;
}
return -1;
return embeddingModels()->defaultModelIndex();
}
ModelInfo ModelList::defaultModelInfo() const
@ -692,8 +694,6 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const
return info->isDefault;
case OnlineRole:
return info->isOnline;
case DisableGUIRole:
return info->disableGUI;
case DescriptionRole:
return info->description();
case RequiresVersionRole:
@ -730,6 +730,8 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const
return info->isClone();
case IsDiscoveredRole:
return info->isDiscovered();
case IsEmbeddingModelRole:
return info->isEmbeddingModel;
case TemperatureRole:
return info->temperature();
case TopPRole:
@ -844,8 +846,6 @@ void ModelList::updateData(const QString &id, const QVector<QPair<int, QVariant>
info->isDefault = value.toBool(); break;
case OnlineRole:
info->isOnline = value.toBool(); break;
case DisableGUIRole:
info->disableGUI = value.toBool(); break;
case DescriptionRole:
info->setDescription(value.toString()); break;
case RequiresVersionRole:
@ -900,6 +900,8 @@ void ModelList::updateData(const QString &id, const QVector<QPair<int, QVariant>
}
break;
}
case IsEmbeddingModelRole:
info->isEmbeddingModel = value.toBool(); break;
case TemperatureRole:
info->setTemperature(value.toDouble()); break;
case TopPRole:
@ -952,11 +954,21 @@ void ModelList::updateData(const QString &id, const QVector<QPair<int, QVariant>
}
// Extra guarantee that these always remains in sync with filesystem
const QFileInfo fileInfo(info->dirpath + info->filename());
QString modelPath = info->dirpath + info->filename();
const QFileInfo fileInfo(modelPath);
info->installed = fileInfo.exists();
const QFileInfo incompleteInfo(incompleteDownloadPath(info->filename()));
info->isIncomplete = incompleteInfo.exists();
// check installed, discovered/sideloaded models only (including clones)
if (!info->checkedEmbeddingModel && !info->isEmbeddingModel && info->installed
&& (info->isDiscovered() || info->description().isEmpty()))
{
// read GGUF and decide based on model architecture
info->isEmbeddingModel = LLModel::Implementation::isEmbeddingModel(modelPath.toStdString());
info->checkedEmbeddingModel = true;
}
if (shouldSort) {
auto s = m_discoverSort;
auto d = m_discoverSortDirection;
@ -983,8 +995,11 @@ void ModelList::resortModel()
emit layoutChanged();
}
void ModelList::updateDataByFilename(const QString &filename, const QVector<QPair<int, QVariant>> &data)
void ModelList::updateDataByFilename(const QString &filename, QVector<QPair<int, QVariant>> data)
{
if (data.isEmpty())
return; // no-op
QVector<QString> modelsById;
{
QMutexLocker locker(&m_mutex);
@ -1041,6 +1056,7 @@ QString ModelList::clone(const ModelInfo &model)
{ ModelList::FilenameRole, model.filename() },
{ ModelList::DirpathRole, model.dirpath },
{ ModelList::OnlineRole, model.isOnline },
{ ModelList::IsEmbeddingModelRole, model.isEmbeddingModel },
{ ModelList::TemperatureRole, model.temperature() },
{ ModelList::TopPRole, model.topP() },
{ ModelList::MinPRole, model.minP() },
@ -1164,8 +1180,7 @@ void ModelList::updateModelsFromDirectory()
if (!it.fileInfo().isDir()) {
QString filename = it.fileName();
// All files that end with .bin and have 'ggml' somewhere in the name
if (((filename.endsWith(".bin") || filename.endsWith(".gguf")) && (/*filename.contains("ggml") ||*/ filename.contains("gguf")) && !filename.startsWith("incomplete"))
if ((filename.endsWith(".gguf") && !filename.startsWith("incomplete"))
|| (filename.endsWith(".txt") && (filename.startsWith("chatgpt-") || filename.startsWith("nomic-")))) {
QString filePath = it.filePath();
@ -1373,16 +1388,19 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
QString parameters = obj["parameters"].toString();
QString quant = obj["quant"].toString();
QString type = obj["type"].toString();
bool isEmbeddingModel = obj["embeddingModel"].toBool();
// Some models aren't supported in the GUI at all
if (disableGUI)
continue;
// If the current version is strictly less than required version, then skip
if (!requiresVersion.isEmpty() && compareVersions(currentVersion, requiresVersion) < 0) {
if (!requiresVersion.isEmpty() && compareVersions(currentVersion, requiresVersion) < 0)
continue;
}
// If the version removed is less than or equal to the current version, then skip
if (!versionRemoved.isEmpty() && compareVersions(versionRemoved, currentVersion) <= 0) {
if (!versionRemoved.isEmpty() && compareVersions(versionRemoved, currentVersion) <= 0)
continue;
}
modelFilesize = ModelList::toFileSize(modelFilesize.toULongLong());
@ -1406,12 +1424,12 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
{ ModelList::RequiresVersionRole, requiresVersion },
{ ModelList::VersionRemovedRole, versionRemoved },
{ ModelList::UrlRole, url },
{ ModelList::DisableGUIRole, disableGUI },
{ ModelList::OrderRole, order },
{ ModelList::RamrequiredRole, ramrequired },
{ ModelList::ParametersRole, parameters },
{ ModelList::QuantRole, quant },
{ ModelList::TypeRole, type },
{ ModelList::IsEmbeddingModelRole, isEmbeddingModel },
};
if (obj.contains("temperature"))
data.append({ ModelList::TemperatureRole, obj["temperature"].toDouble() });
@ -1515,7 +1533,7 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
{ ModelList::FilenameRole, modelFilename },
{ ModelList::FilesizeRole, "minimal" },
{ ModelList::OnlineRole, true },
{ ModelList::DisableGUIRole, true },
{ ModelList::IsEmbeddingModelRole, true },
{ ModelList::DescriptionRole,
tr("<strong>LocalDocs Nomic Atlas Embed</strong><br>") + nomicEmbedDesc },
{ ModelList::RequiresVersionRole, "2.6.3" },

@ -16,7 +16,6 @@ struct ModelInfo {
Q_PROPERTY(bool calcHash MEMBER calcHash)
Q_PROPERTY(bool installed MEMBER installed)
Q_PROPERTY(bool isDefault MEMBER isDefault)
Q_PROPERTY(bool disableGUI MEMBER disableGUI)
Q_PROPERTY(bool isOnline MEMBER isOnline)
Q_PROPERTY(QString description READ description WRITE setDescription)
Q_PROPERTY(QString requiresVersion MEMBER requiresVersion)
@ -36,6 +35,7 @@ struct ModelInfo {
Q_PROPERTY(QString type READ type WRITE setType)
Q_PROPERTY(bool isClone READ isClone WRITE setIsClone)
Q_PROPERTY(bool isDiscovered READ isDiscovered WRITE setIsDiscovered)
Q_PROPERTY(bool isEmbeddingModel MEMBER isEmbeddingModel)
Q_PROPERTY(double temperature READ temperature WRITE setTemperature)
Q_PROPERTY(double topP READ topP WRITE setTopP)
Q_PROPERTY(double minP READ minP WRITE setMinP)
@ -104,7 +104,6 @@ public:
bool installed = false;
bool isDefault = false;
bool isOnline = false;
bool disableGUI = false;
QString requiresVersion;
QString versionRemoved;
qint64 bytesReceived = 0;
@ -117,6 +116,8 @@ public:
QString order;
int ramrequired = -1;
QString parameters;
bool isEmbeddingModel = false;
bool checkedEmbeddingModel = false;
bool operator==(const ModelInfo &other) const {
return m_id == other.m_id;
@ -187,9 +188,10 @@ class EmbeddingModels : public QSortFilterProxyModel
Q_OBJECT
Q_PROPERTY(int count READ count NOTIFY countChanged)
public:
explicit EmbeddingModels(QObject *parent);
int count() const;
EmbeddingModels(QObject *parent, bool requireInstalled);
int count() const { return rowCount(); }
int defaultModelIndex() const;
ModelInfo defaultModelInfo() const;
Q_SIGNALS:
@ -198,6 +200,9 @@ Q_SIGNALS:
protected:
bool filterAcceptsRow(int sourceRow, const QModelIndex &sourceParent) const override;
private:
bool m_requireInstalled;
};
class InstalledModels : public QSortFilterProxyModel
@ -206,7 +211,7 @@ class InstalledModels : public QSortFilterProxyModel
Q_PROPERTY(int count READ count NOTIFY countChanged)
public:
explicit InstalledModels(QObject *parent);
int count() const;
int count() const { return rowCount(); }
Q_SIGNALS:
void countChanged();
@ -248,8 +253,8 @@ class ModelList : public QAbstractListModel
{
Q_OBJECT
Q_PROPERTY(int count READ count NOTIFY countChanged)
Q_PROPERTY(int defaultEmbeddingModelIndex READ defaultEmbeddingModelIndex NOTIFY defaultEmbeddingModelIndexChanged)
Q_PROPERTY(EmbeddingModels* embeddingModels READ embeddingModels NOTIFY embeddingModelsChanged)
Q_PROPERTY(int defaultEmbeddingModelIndex READ defaultEmbeddingModelIndex)
Q_PROPERTY(EmbeddingModels* installedEmbeddingModels READ installedEmbeddingModels NOTIFY installedEmbeddingModelsChanged)
Q_PROPERTY(InstalledModels* installedModels READ installedModels NOTIFY installedModelsChanged)
Q_PROPERTY(DownloadableModels* downloadableModels READ downloadableModels NOTIFY downloadableModelsChanged)
Q_PROPERTY(QList<QString> userDefaultModelList READ userDefaultModelList NOTIFY userDefaultModelListChanged)
@ -282,7 +287,6 @@ public:
InstalledRole,
DefaultRole,
OnlineRole,
DisableGUIRole,
DescriptionRole,
RequiresVersionRole,
VersionRemovedRole,
@ -301,6 +305,7 @@ public:
TypeRole,
IsCloneRole,
IsDiscoveredRole,
IsEmbeddingModelRole,
TemperatureRole,
TopPRole,
TopKRole,
@ -332,7 +337,6 @@ public:
roles[InstalledRole] = "installed";
roles[DefaultRole] = "isDefault";
roles[OnlineRole] = "isOnline";
roles[DisableGUIRole] = "disableGUI";
roles[DescriptionRole] = "description";
roles[RequiresVersionRole] = "requiresVersion";
roles[VersionRemovedRole] = "versionRemoved";
@ -351,6 +355,7 @@ public:
roles[TypeRole] = "type";
roles[IsCloneRole] = "isClone";
roles[IsDiscoveredRole] = "isDiscovered";
roles[IsEmbeddingModelRole] = "isEmbeddingModel";
roles[TemperatureRole] = "temperature";
roles[TopPRole] = "topP";
roles[MinPRole] = "minP";
@ -373,7 +378,7 @@ public:
QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override;
QVariant data(const QString &id, int role) const;
QVariant dataByFilename(const QString &filename, int role) const;
void updateDataByFilename(const QString &filename, const QVector<QPair<int, QVariant>> &data);
void updateDataByFilename(const QString &filename, QVector<QPair<int, QVariant>> data);
void updateData(const QString &id, const QVector<QPair<int, QVariant>> &data);
int count() const { return m_models.size(); }
@ -396,6 +401,7 @@ public:
const QList<QString> userDefaultModelList() const;
EmbeddingModels *embeddingModels() const { return m_embeddingModels; }
EmbeddingModels *installedEmbeddingModels() const { return m_installedEmbeddingModels; }
InstalledModels *installedModels() const { return m_installedModels; }
DownloadableModels *downloadableModels() const { return m_downloadableModels; }
@ -433,12 +439,11 @@ public:
Q_SIGNALS:
void countChanged();
void embeddingModelsChanged();
void installedEmbeddingModelsChanged();
void installedModelsChanged();
void downloadableModelsChanged();
void userDefaultModelListChanged();
void asyncModelRequestOngoingChanged();
void defaultEmbeddingModelIndexChanged();
void discoverLimitChanged();
void discoverSortDirectionChanged();
void discoverSortChanged();
@ -474,6 +479,7 @@ private:
mutable QMutex m_mutex;
QNetworkAccessManager m_networkManager;
EmbeddingModels *m_embeddingModels;
EmbeddingModels *m_installedEmbeddingModels;
InstalledModels *m_installedModels;
DownloadableModels *m_downloadableModels;
QList<ModelInfo*> m_models;
@ -488,7 +494,7 @@ private:
protected:
explicit ModelList();
~ModelList() {}
~ModelList() { for (auto *model: m_models) { delete model; } }
friend class MyModelList;
};

@ -14,7 +14,7 @@ MySettingsTab {
MySettings.restoreLocalDocsDefaults();
}
property bool hasEmbeddingModel: ModelList.embeddingModels.count !== 0
property bool hasEmbeddingModel: ModelList.installedEmbeddingModels.count !== 0
showAdvancedSettingsButton: hasEmbeddingModel
showRestoreDefaultsButton: hasEmbeddingModel

@ -24,7 +24,7 @@ MyDialog {
if (showEmbeddingModels) {
ModelList.downloadableModels.expanded = true
var targetModelIndex = ModelList.defaultEmbeddingModelIndex
modelListView.positionViewAtIndex(targetModelIndex, ListView.Contain)
modelListView.positionViewAtIndex(targetModelIndex, ListView.Beginning)
}
}

Loading…
Cancel
Save