Skip to content
Navigation Menu
{{ message }}
forked from ggml-org/llama.cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest-save-load-state.cpp
More file actions
358 lines (277 loc) · 12 KB
/
Copy pathtest-save-load-state.cpp
File metadata and controls
358 lines (277 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama-cpp.h"
#include <clocale>
#include <random>
#include <vector>
struct llama_batch_ptr {
llama_batch batch;
llama_batch_ptr(int32_t n_tokens, int32_t embd, int32_t n_seq_max)
: batch{llama_batch_init(n_tokens, embd, n_seq_max)} {}
~llama_batch_ptr() { llama_batch_free(batch); }
llama_batch_ptr(const llama_batch_ptr &) = delete;
llama_batch_ptr & operator=(const llama_batch_ptr &) = delete;
llama_batch_ptr(llama_batch_ptr &&) = default;
llama_batch_ptr & operator=(llama_batch_ptr &&) = default;
llama_batch & get() { return batch; }
const llama_batch & get() const { return batch; }
};
// Samples and decodes `n_predict` tokens on sequence `seq_id`.
//
// Each step samples one token from `smpl` (output index -1), logs it, appends
// it to the result, and decodes it as a single-token batch at position
// `n_past`; `n_past` is advanced after every successful decode.
//
// Returns the sampled tokens, or an empty vector when llama_decode() reports a
// failure. The result is also empty when `n_predict` is not positive, because
// no step runs in that case.
static llama_tokens generate_tokens(llama_context * ctx, llama_sampler * smpl, int & n_past, int32_t n_predict, llama_seq_id seq_id) {
    llama_tokens generated;

    // One-slot batch (1 token, no embeddings, 1 sequence id) reused for every step.
    llama_batch_ptr holder(1, 0, 1);
    llama_batch &   step = holder.get();

    for (int32_t remaining = n_predict; remaining > 0; --remaining) {
        const auto sampled = llama_sampler_sample(smpl, ctx, -1);
        LOG("%d ", sampled);
        generated.push_back(sampled);

        common_batch_clear(step);
        common_batch_add(step, sampled, n_past, {seq_id}, true);

        const bool decode_failed = llama_decode(ctx, step) != 0;
        if (decode_failed) {
            LOG_ERR("\n%s: failed to evaluate\n", __func__);
            return {};
        }

        ++n_past;
    }

    return generated;
}
// Test 1: baseline
// - decode all but the last token
// - save state to disk
// - decode the last token
// - generate n_predict tokens
//
// Creates a fresh context and a seeded `dist` sampler, decodes the prompt via
// common_prompt_batch_decode() (which receives params.out_file as the state
// path), then generates params.n_predict tokens on sequence 0.
//
// Returns the generated tokens; an empty vector signals failure (context
// creation, prompt decode, or generation).
static llama_tokens test_baseline(struct llama_model * model, const struct common_params & params, const llama_tokens & tokens) {
    auto ctx = llama_context_ptr{llama_init_from_model(model, common_context_params_to_llama(params))};
    if (!ctx) {
        // Without this check the null context would be handed to the decode helpers below.
        LOG_ERR("%s: failed to create context\n", __func__);
        return {};
    }

    auto sparams = llama_sampler_chain_default_params();
    auto smpl    = llama_sampler_ptr{llama_sampler_chain_init(sparams)};
    llama_sampler_chain_add(smpl.get(), llama_sampler_init_dist(params.sampling.seed));

    int n_past = 0;
    // NOTE(review): the trailing `true` presumably enables saving the state to params.out_file - confirm against common.h.
    if (!common_prompt_batch_decode(ctx.get(), tokens, (int)tokens.size(), n_past, params.n_batch, params.out_file, true)) {
        LOG_ERR("%s: failed to decode prompt\n", __func__);
        return {};
    }

    LOG("\n=== Test 1: baseline ===\n");

    auto result = generate_tokens(ctx.get(), smpl.get(), n_past, params.n_predict, 0);
    if (result.empty()) {
        return {};
    }
    LOG("\n");

    return result;
}
// Test 2: state load
// - create a new context
// - load state from file
// - replay the last prompt token
// - generate n_predict tokens and compare against expected result
static bool test_state_load(struct llama_model * model, const struct common_params & params, const llama_tokens & tokens, const llama_tokens & expected_result) {
auto ctx = llama_context_ptr{llama_init_from_model(model, common_context_params_to_llama(params))};
auto sparams = llama_sampler_chain_default_params();
auto smpl = llama_sampler_ptr{llama_sampler_chain_init(sparams)};
llama_sampler_chain_add(smpl.get(), llama_sampler_init_dist(params.sampling.seed));
LOG("\n=== Test 2: state load ===\n");
// Load state from file
llama_tokens unused_sts(tokens.size());
size_t n_token_count_out = 0;
if (!llama_state_load_file(ctx.get(), params.out_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
LOG_ERR("\n%s: failed to load state\n", __func__);
return false;
}
LOG_TRC("%s: loaded state with %zu tokens\n", __func__, n_token_count_out);
// Replay last token
int n_past = (int) n_token_count_out - 1;
if (!common_replay_last_token(ctx.get(), tokens.back(), n_past)) {
return false;
}
n_past++;
// Generate tokens
auto result = generate_tokens(ctx.get(), smpl.get(), n_past, params.n_predict, 0);
if (result.empty()) {
return false;
}
if (result != expected_result) {
LOG_ERR("\n%s: error: generation differs from expected\n", __func__);
return false;
}
LOG("\nPASS\n");
return true;
}
// Test 3: seq copy (host)
// - create a multi-seq context
// - load state from file
// - replay the last prompt token
// - migrate KV cache from seq 0 to seq 1 via the CPU path
// - generate n_predict tokens on seq 1 and compare against expected result
static bool test_seq_cp_host(struct llama_model * model, const struct common_params & params, const llama_tokens & tokens, const llama_tokens & expected_result) {
auto params_ctx = common_context_params_to_llama(params);
params_ctx.n_seq_max = 2;
auto ctx = llama_context_ptr{llama_init_from_model(model, params_ctx)};
auto sparams = llama_sampler_chain_default_params();
auto smpl = llama_sampler_ptr{llama_sampler_chain_init(sparams)};
llama_sampler_chain_add(smpl.get(), llama_sampler_init_dist(params.sampling.seed));
LOG("\n=== Test 3: seq copy (host) ===\n");
// Load state from file
llama_tokens unused_sts(tokens.size());
size_t n_token_count_out = 0;
if (!llama_state_load_file(ctx.get(), params.out_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
LOG_ERR("\n%s: failed to load state\n", __func__);
return false;
}
LOG_TRC("%s: loaded state with %zu tokens\n", __func__, n_token_count_out);
// Replay last token
int n_past = (int) n_token_count_out - 1;
if (!common_replay_last_token(ctx.get(), tokens.back(), n_past)) {
return false;
}
n_past++;
// Migrate KV cache from seq 0 to seq 1 (CPU path)
{
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx.get(), 0));
const size_t ncopy = llama_state_seq_get_data(ctx.get(), seq_store.data(), seq_store.size(), 0);
if (ncopy != seq_store.size()) {
LOG_ERR("\n%s: seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
return false;
}
LOG_TRC("%s: seq 0 copied, %zd bytes\n", __func__, ncopy);
llama_memory_clear(llama_get_memory(ctx.get()), true);
LOG_TRC("%s: kv cache cleared\n", __func__);
const size_t nset = llama_state_seq_set_data(ctx.get(), seq_store.data(), seq_store.size(), 1);
if (nset != seq_store.size()) {
LOG_ERR("\n%s: seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
return false;
}
LOG_TRC("%s: seq 1 restored, %zd bytes\n", __func__, nset);
}
// Generate tokens on seq 1
auto result = generate_tokens(ctx.get(), smpl.get(), n_past, params.n_predict, 1);
if (result.empty()) {
return false;
}
if (result != expected_result) {
LOG_ERR("\n%s: error: generation differs from expected\n", __func__);
return false;
}
LOG("\nPASS\n");
return true;
}
// Test 4: seq copy (device)
// - create a multi-seq context
// - load state from file
// - replay the last prompt token
// - migrate KV cache from seq 0 to seq 1 via the on-device path
// - generate n_predict tokens on seq 1 and compare against expected result
static bool test_seq_cp_device(struct llama_model * model, const struct common_params & params, const llama_tokens & tokens, const llama_tokens & expected_result) {
auto params_ctx = common_context_params_to_llama(params);
params_ctx.n_seq_max = 2;
auto ctx = llama_context_ptr{llama_init_from_model(model, params_ctx)};
auto sparams = llama_sampler_chain_default_params();
auto smpl = llama_sampler_ptr{llama_sampler_chain_init(sparams)};
llama_sampler_chain_add(smpl.get(), llama_sampler_init_dist(params.sampling.seed));
LOG("\n=== Test 4: seq copy (device) ===\n");
// Load state from file
llama_tokens unused_sts(tokens.size());
size_t n_token_count_out = 0;
if (!llama_state_load_file(ctx.get(), params.out_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
LOG_ERR("\n%s: failed to load state\n", __func__);
return false;
}
LOG_TRC("%s: loaded state with %zu tokens\n", __func__, n_token_count_out);
// Replay last token
int n_past = (int) n_token_count_out - 1;
if (!common_replay_last_token(ctx.get(), tokens.back(), n_past)) {
return false;
}
n_past++;
// Migrate KV cache from seq 0 to seq 1 (on-device path)
{
std::vector<uint8_t> seq_store(llama_state_seq_get_size_ext(ctx.get(), 0, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE));
const size_t ncopy = llama_state_seq_get_data_ext(ctx.get(), seq_store.data(), seq_store.size(), 0, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (ncopy != seq_store.size()) {
LOG_ERR("\n%s: seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
return false;
}
LOG_TRC("%s: seq 0 copied, %zd bytes\n", __func__, ncopy);
llama_memory_clear(llama_get_memory(ctx.get()), true);
LOG_TRC("%s: kv cache cleared\n", __func__);
const size_t nset = llama_state_seq_set_data_ext(ctx.get(), seq_store.data(), seq_store.size(), 1, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (nset != seq_store.size()) {
LOG_ERR("\n%s: seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
return false;
}
LOG_TRC("%s: seq 1 restored, %zd bytes\n", __func__, nset);
}
// Generate tokens on seq 1
auto result = generate_tokens(ctx.get(), smpl.get(), n_past, params.n_predict, 1);
if (result.empty()) {
return false;
}
if (result != expected_result) {
LOG_ERR("\n%s: error: generation differs from expected\n", __func__);
return false;
}
LOG("\nPASS\n");
return true;
}
// Entry point of the save/load-state test.
//
// Sets the test defaults (empty prompt, n_batch = 100, out_file =
// "dump_state.bin", sampling seed = 1234), parses the command line, loads the
// model, builds the prompt tokens (tokenized prompt, or n_batch random tokens
// when no prompt is given) and runs tests 1-4 in order, each compared against
// the baseline generation of test 1.
//
// Returns 0 when every test passes and 1 on the first failure.
int main(int argc, char ** argv) {
    std::setlocale(LC_NUMERIC, "C");

    common_params params;
    params.prompt        = "";
    params.n_batch       = 100;
    params.out_file      = "dump_state.bin";
    params.sampling.seed = 1234;

    common_init();

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
        return 1;
    }

    if (params.n_parallel == 1) {
        LOG_TRC("%s: n_parallel == 1, enabling unified kv cache\n", __func__);
        params.kv_unified = true;
    }

    if (params.n_predict < 0) {
        params.n_predict = 16;
    }
    if (params.n_predict == 0) {
        // generate_tokens() returns an empty result for zero predictions, which every
        // test treats as a failure; report the reason instead of failing silently.
        LOG_ERR("%s: n_predict must be greater than 0\n", __func__);
        return 1;
    }

    ggml_backend_load_all();

    // NOTE(review): the `true` argument presumably requests a model-only init;
    // the assert below expects that no context was created - confirm against common.h.
    auto llama_init = common_init_from_params(params, true);

    auto * model = llama_init->model();
    if (model == nullptr) {
        LOG_ERR("%s: failed to init\n", __func__);
        return 1;
    }

    GGML_ASSERT(llama_init->context() == nullptr);

    const auto * vocab = llama_model_get_vocab(model);

    // Tokenize prompt or generate random tokens
    llama_tokens tokens;
    if (params.prompt.empty()) {
        const int n_prompt = params.n_batch;

        // this path is useful for model files that do not have a tokenizer
        LOG_INF("%s: no prompt provided, generating %d (n_batch) random tokens\n", __func__, n_prompt);

        const auto n_vocab = llama_vocab_n_tokens(vocab);

        std::mt19937 rng(params.sampling.seed);
        std::uniform_int_distribution<llama_token> dist(0, n_vocab - 1);
        for (int i = 0; i < n_prompt; i++) {
            tokens.push_back(dist(rng));
        }
    } else {
        LOG_INF("%s: tokenizing prompt '%s'\n", __func__, params.prompt.c_str());

        // Tokenize through the vocab: no throw-away context has to be created
        // (and checked for failure) just to reach the tokenizer.
        tokens = common_tokenize(vocab, params.prompt, true);
    }

    LOG_INF("%s: the input prompt is %d tokens\n", __func__, (int)tokens.size());

    if (tokens.empty()) {
        // The tests replay tokens.back(); an empty prompt cannot be tested.
        LOG_ERR("%s: the prompt produced no tokens\n", __func__);
        return 1;
    }

    // Test 1: baseline (saves state to disk)
    auto result_baseline = test_baseline(model, params, tokens);
    if (result_baseline.empty()) {
        return 1;
    }

    // Test 2: state load
    if (!test_state_load(model, params, tokens, result_baseline)) {
        return 1;
    }

    // Test 3: seq copy (host)
    if (!test_seq_cp_host(model, params, tokens, result_baseline)) {
        return 1;
    }

    // Test 4: seq copy (device)
    if (!test_seq_cp_device(model, params, tokens, result_baseline)) {
        return 1;
    }

    LOG("\nAll tests passed.\n");
    return 0;
}
You can’t perform that action at this time.
