Commit f5483ea

lora: Lycoris LoHa support + refactor a bit
1 parent 0bb048c commit f5483ea

2 files changed: +461 −368 lines changed

ggml_extend.hpp

Lines changed: 31 additions & 5 deletions
@@ -52,8 +52,6 @@
 #define __STATIC_INLINE__ static inline
 #endif
 
-__STATIC_INLINE__ void print_ggml_tensor(struct ggml_tensor* tensor, bool shape_only, const char* mark);
-
 // n-mode tensor-matrix product
 // example: 2-mode product
 // A: [ne03, k, ne01, ne00]
@@ -62,7 +60,7 @@ __STATIC_INLINE__ void print_ggml_tensor(struct ggml_tensor* tensor, bool shape_
 __STATIC_INLINE__ struct ggml_tensor* ggml_mul_n_mode(struct ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b, int mode = 0) {
     // reshape A
     // swap 0th and nth axis
-    a = ggml_cont(ctx, ggml_permute(ctx, a, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0));
+    a = ggml_cont(ctx, ggml_permute(ctx, a, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0));
     int ne1 = a->ne[1];
     int ne2 = a->ne[2];
     int ne3 = a->ne[3];
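As background (standard Tucker-decomposition notation, not code or text from this commit): ggml_mul_n_mode implements the mode-n tensor-matrix product, which contracts one axis of a tensor A against a matrix B. In conventional index notation (not ggml's reversed ne[] order):

(A \times_n B)_{i_1 \dots i_{n-1}\, j\, i_{n+1} \dots i_N} \;=\; \sum_{k} A_{i_1 \dots i_{n-1}\, k\, i_{n+1} \dots i_N}\, B_{j\,k}

Every axis of A is preserved except the contracted one, which is replaced by B's output axis. Judging from the comment and the permute above, the mode argument appears to select which ne[] axis of a is contracted (mode 2 contracts the k axis in the [ne03, k, ne01, ne00] example).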
@@ -78,6 +76,34 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_mul_n_mode(struct ggml_context* ctx,
     return result;
 }
 
+__STATIC_INLINE__ struct ggml_tensor* ggml_merge_lora(ggml_context* ctx, struct ggml_tensor* lora_down, struct ggml_tensor* lora_up, struct ggml_tensor* lora_mid = NULL) {
+    struct ggml_tensor* updown;
+    // flatten the lora tensors so they can be multiplied as matrices
+    int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1];
+    lora_up              = ggml_reshape_2d(ctx, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows);
+    auto lora_down_n_dims = ggml_n_dims(lora_down);
+    // round n_dims up to a multiple of 2 (otherwise rank 1 doesn't work)
+    lora_down_n_dims       = (lora_down_n_dims + lora_down_n_dims % 2);
+    int64_t lora_down_rows = lora_down->ne[lora_down_n_dims - 1];
+    lora_down              = ggml_reshape_2d(ctx, lora_down, ggml_nelements(lora_down) / lora_down_rows, lora_down_rows);
+
+    // ggml_mul_mat requires tensor b transposed
+    lora_down = ggml_cont(ctx, ggml_transpose(ctx, lora_down));
+    if (lora_mid == NULL) {
+        updown = ggml_mul_mat(ctx, lora_up, lora_down);
+        updown = ggml_cont(ctx, ggml_transpose(ctx, updown));
+    } else {
+        // undoing tucker decomposition for conv layers.
+        // lora_mid  has shape (3, 3, Rank, Rank)
+        // lora_down has shape (Rank, In, 1, 1)
+        // lora_up   has shape (Rank, Out, 1, 1)
+        // conv layer shape is (3, 3, Out, In)
+        updown = ggml_mul_n_mode(ctx, ggml_mul_n_mode(ctx, lora_mid, lora_down, 3), lora_up, 2);
+        updown = ggml_cont(ctx, updown);
+    }
+    return updown;
+}
+
 __STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void* user_data) {
     (void)level;
     (void)user_data;
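The commit title adds Lycoris LoHa support; the LoHa-specific loading presumably lives in the other changed file (not shown here), but ggml_merge_lora above is the building block it composes. A LoHa module stores two low-rank pairs (plus optional Tucker cores for conv layers), and its weight delta is the elementwise (Hadamard) product of the two low-rank products. A minimal sketch of how that composition could look on top of ggml_merge_lora — the helper name merge_loha_sketch and its argument order are illustrative, not APIs from this commit:

// Illustrative sketch, not code from this commit: combine two LoHa factor
// pairs into a single weight delta using ggml_merge_lora defined above.
//   delta_W = (w1_up @ w1_down) * (w2_up @ w2_down)   (elementwise product)
// w1_mid / w2_mid are the optional Tucker cores used for conv layers,
// exactly as in ggml_merge_lora.
__STATIC_INLINE__ struct ggml_tensor* merge_loha_sketch(ggml_context* ctx,
                                                        struct ggml_tensor* w1_down, struct ggml_tensor* w1_up,
                                                        struct ggml_tensor* w2_down, struct ggml_tensor* w2_up,
                                                        struct ggml_tensor* w1_mid = NULL,
                                                        struct ggml_tensor* w2_mid = NULL) {
    struct ggml_tensor* updown1 = ggml_merge_lora(ctx, w1_down, w1_up, w1_mid);
    struct ggml_tensor* updown2 = ggml_merge_lora(ctx, w2_down, w2_up, w2_mid);
    // both products have the shape of the target weight, so a plain
    // elementwise multiply yields the LoHa delta
    return ggml_mul(ctx, updown1, updown2);
}

In Lycoris checkpoints the pairs are usually stored as hada_w1_a/hada_w1_b and hada_w2_a/hada_w2_b (with hada_t1/hada_t2 as the optional Tucker cores); the *_a tensors play the role of lora_up and the *_b tensors of lora_down, and any alpha/rank scaling is typically applied to the merged delta afterwards, as with regular LoRA.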
@@ -1013,8 +1039,8 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
 }
 
 /* SDXL with LoRA requires more space */
-#define MAX_PARAMS_TENSOR_NUM 15360
-#define MAX_GRAPH_SIZE 15360
+#define MAX_PARAMS_TENSOR_NUM 16384
+#define MAX_GRAPH_SIZE 16384
 
 struct GGMLRunner {
 protected:
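For context on the last hunk: constants like these bound how many weight tensors the params context can describe and how many nodes a compute graph may hold, and merging several LoRA/LoHa networks on top of an SDXL graph can push both past the old 15360 limit. A rough sketch of how such constants typically feed into ggml allocations — the helper functions below are illustrative, not GGMLRunner's actual code:

// Illustrative sketch, not code from this commit: how MAX_PARAMS_TENSOR_NUM
// and MAX_GRAPH_SIZE would typically size ggml allocations.
#include "ggml.h"
#include "ggml_extend.hpp"  // assumed include; defines the two constants above

// Context that only holds tensor metadata for up to MAX_PARAMS_TENSOR_NUM
// weight tensors; the actual data is expected to live in a backend buffer.
static struct ggml_context* params_ctx_sketch() {
    struct ggml_init_params params;
    params.mem_size   = MAX_PARAMS_TENSOR_NUM * ggml_tensor_overhead();
    params.mem_buffer = NULL;
    params.no_alloc   = true;
    return ggml_init(params);
}

// Compute graph with room for up to MAX_GRAPH_SIZE nodes, instead of the
// default GGML_DEFAULT_GRAPH_SIZE; the owning context must reserve at least
// ggml_graph_overhead_custom(MAX_GRAPH_SIZE, false) bytes for it.
static struct ggml_cgraph* graph_sketch(struct ggml_context* compute_ctx) {
    return ggml_new_graph_custom(compute_ctx, MAX_GRAPH_SIZE, /*grads=*/false);
}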
