TL;DR Il caricamento dei modelli ML è lento, anche con la cache delle pagine Linux già calda. Quindi abbiamo creato una libreria per renderlo veloce. Ci sono alcuni dettagli tecnici interessanti che vogliamo condividere, quindi abbiamo scritto questo blog. La libreria ha avuto anche un impatto inaspettato, discusso alla fine.
Motivazione
Tutto inizia 2 anni fa, quando abbiamo lanciato il nostro primo tentativo di modalità di generazione lowpoly. La modalità lowpoly non è andata bene, emette risultati scarsi dal punto di vista odierno, ma abbiamo pagato molto per essa -- una GPU dedicata elabora solo un numero a una cifra di compiti al giorno. Ha pesi finemente sintonizzati, abbastanza grandi da spingere fuori dalla VRAM tutti gli altri pesi del modello. Peggio ancora, abbiamo forse 3 di tali modelli (non ricordo il numero esatto), costituivano una parte significativa della nostra infrastruttura di inferenza, creando un rapporto di efficienza piuttosto severo. E no, non possiamo caricare ingenuamente i modelli just-in-time, costa 30 secondi, più del tempo effettivo di elaborazione.
Non avevamo ingegneri di pipeline dedicati allora, i nostri sviluppatori di algoritmi hanno fatto del loro meglio per aggirare questo problema. Giorni dopo, il nostro codice era disseminato di this.to('cpu') e that.to('cuda'). Questo approccio funziona per un po', ma interrompe il flusso dei nostri sviluppatori di algoritmi di tanto in tanto. E se le cose potessero accadere in modo automatico? È Python, le cose accadono in modo automatico in Python.
Come definisci 'automatico'?
Passiamo al ruolo di uno sviluppatore di algoritmi. Le cose sono abbastanza chiare: non voglio preoccuparmi delle prestazioni di runtime al di fuori del mio algoritmo principale a meno che non sia assolutamente necessario. Preferirei non sapere nulla sullo scambio di modelli dentro e fuori.
Ovviamente non possiamo raggiungere questo obiettivo, ma possiamo cercare di minimizzare l'intrusione che dobbiamo introdurre nel codice dell'algoritmo. Questo mi ricorda il monkey-patching della libreria gevent, che modifica (principalmente) la libreria socket, sostituendola con gevent.socket che può passare ad altri greenlet quando l'IO bloccherebbe, molto simile a una goroutine (in realtà gevent è più vecchio di Golang!).
Poiché stavamo usando solo le librerie di HuggingFace (transformers, diffusers) per caricare i modelli al tempo, l'obiettivo è diventato chiaro: introduciamo solo una chiamata di monkey-patch e il resto del codice dovrebbe rimanere invariato, XXXPipeline.from_pretrained(...) dovrebbe essere molto più veloce.
Alcuni Fatti, Decisioni Ovvie e Assunzioni
Overmind è una libreria di caching, memorizza i risultati delle chiamate di caricamento dei modelli nella memoria di sistema e li ricostruisce rapidamente.
Saltiamo la discussione su come è implementato il monkey-patching, è un dettaglio non così interessante. Tutto ciò che dobbiamo sapere è che reindirizza tutte le chiamate XXXPipeline.from_pretrained(...) a overmind.api.load(XXXPipeline.from_pretrained, ...).
Usiamo pickle per serializzare il risultato della cache poiché... non abbiamo scelta, e torch.save stesso usa pickle, è strano non usarlo.
Usiamo un'architettura client/server poiché non vogliamo invalidare la nostra cache quando il processo termina. Ci sono molte chiamate a sottoprocessi che potrebbero beneficiarne.
Assumiamo che i parametri di XXXPipeline.from_pretrained siano cose semplici hashabili (str e cose simili) e altri modelli caricati da overmind (spiegato più avanti).
Il nome overmind è preso in prestito da Starcraft, come avrete intuito.
Ricostruirlo velocemente!
Non possiamo ingenuamente salvare il risultato di pickle.loads in memoria e considerarlo fatto. Dopotutto, in uno scenario riscaldato, la cache delle pagine Linux ha fatto il suo lavoro memorizzando i modelli su disco e possiamo ancora vedere un tempo di caricamento misurato in decine di secondi.
L'inefficienza deriva dalla copia della memoria. In Python, anche creare milioni di oggetti costerebbe non più di alcune centinaia di millisecondi. Tuttavia, per una copia di memoria di 10 GiB, costerebbe mezzo secondo. Dobbiamo evitare la copia della memoria il più possibile.
Fortunatamente, la maggior parte dei grandi blocchi di memoria sono tensori di Torch, possiamo tranquillamente indirizzare solo loro e ignorare il resto.
In realtà, ho acquisito la conoscenza della struttura interna di un tensore Torch nel codice di riduzione mentre ricercavo il meccanismo di condivisione dei tensori:
# Copiato da torch.multiprocessing.reductions, la maggior parte del codice è stata rimossa
def reduce_tensor(tensor):
...
storage = tensor._typed_storage()
...
metadata = (
tensor.storage_offset(),
tensor.size(),
tensor.stride(),
tensor.requires_grad,
)
return (rebuild_tensor, (type(tensor), storage, metadata))Abbastanza semplice: un tensore è il suo tipo, i suoi metadati e il suo storage sottostante. Qui storage è di tipo TypedStorage, ma in realtà TypedStorage è solo un semplice wrapper per UntypedStorage. UntypedStorage è la classe che effettivamente contiene tutti i dati del tensore.
Il nostro compito diventa ora più specifico: come possiamo evitare di copiare UntypedStorage? Possiamo gestire noi stessi la memoria di questi tensori e costruire UntypedStorage puntando alla memoria che gestiamo?
La risposta è sì!
Scorrendo il codice C++ dove UntypedStorage è costruito, possiamo facilmente trovare un frammento di codice come questo:
// Copiato da torch/csrc/Storage.cpp
static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
// ...omettendo codice non correlato...
auto new_storage_impl = make_storage_impl(
c10::StorageImpl::use_byte_size_t(),
slicelength,
at::DataPtr(
static_cast<void*>(data + start),
old_storage_impl,
[](void* s) {
c10::raw::intrusive_ptr::decref(static_cast<at::StorageImpl*>(s));
},
old_storage_impl->device()),
old_storage_impl->allocator(),
/* resizable */ false,
device_opt);
PyObject* _ret =
THPStorage_NewWithStorage(Py_TYPE(self), std::move(new_storage_impl));
return _ret;
}Non solo possiamo usare un puntatore, ma la classe at::DataPtr può anche gestire la distruzione, rendendo la gestione del ciclo di vita molto più semplice.
Dal lato Python, un puntatore a una regione di memoria è rappresentato da un oggetto memoryview, questi oggetti supportano il protocollo buffer. Possiamo ottenere un oggetto memoryview da molte cose, bytes e mmap sono le 2 principali cose che lo supportano, e sono anche quelle che ci interessano.
Infine, sappiamo cosa dovremmo fare: creare una funzione che accetti un oggetto memoryview e lo trasformi in un UntypedStorage senza copiarlo. Con la capacità di ricostruire UntypedStorage da memoryview, i dati effettivi del tensore non devono essere nel flusso pickle, riducendo notevolmente la dimensione dei dati che dobbiamo copiare.
void initOvermindHelpers(py::module m) {
// ...
m.def("_make_untyped_storage", [](py::buffer b) {
auto info = new py::buffer_info(b.request());
return pybind11::reinterpret_steal<py::object>(THPStorage_NewWithStorage(
THPStorageClass,
c10::make_intrusive<at::StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
info->size,
at::DataPtr(
info->ptr, info,
[](void* ptr) {
py::gil_scoped_acquire gil;
auto b = static_cast<py::buffer_info*>(ptr);
delete b;
},
at::DeviceType::CPU
),
/*allocator=*/nullptr,
/*resizable=*/false,
)
));
});
}Questo è il blocco costitutivo principale di overmind.
Condivisione dei tensori!
Nota: Esiste già un meccanismo di condivisione dei tensori in PyTorch, ma non soddisfa le nostre esigenze. Maggiori dettagli su questo più avanti.
Prima di tutto, condivisione della memoria tra client e server
Quando vediamo 'condivisione' e 'memoria' insieme, abbiamo tutti l'impulso di usare shmget e i suoi amici. È "progettato" per essere utilizzato come meccanismo di condivisione della memoria, giusto? Ma ha 2 difetti principali:
- POSIX shm è una risorsa scarsa, ciò che puoi usare è determinato da come l'amministratore di sistema configura il sistema. Un esempio estremo ma onnipresente sono i container Docker, per impostazione predefinita hai solo 64MiB di shm POSIX utilizzabile.
- La memoria condivisa POSIX vive più a lungo del tuo processo, devi gestirla autonomamente. Se il processo di gestione viene terminato forzatamente o non viene gestito con cura, l'oggetto shm potrebbe rimanere nel sistema indefinitamente.
Se guardi attentamente, Linux è pieno di chiamate di sistema interessanti. memfd_create è una di quelle che ci interessa: ti fornisce un fd che rappresenta un'allocazione di memoria anonima. Puoi eseguire tutti i tipi di operazioni sui file: leggere, scrivere e, naturalmente, mmap. Se possiamo condividere il fd, possiamo condividere la memoria.
Condividere un fd ha un modo 'standard' ma arcano per farlo: sendmsg con SCM_RIGHTS. Possiamo sfruttare le librerie per aiutarci a nascondere i dettagli scoraggianti del processo sendmsg, ma dobbiamo comunque coordinare i processi tra server e client. Abbiamo deciso di usare un trucco qui: basta aprire /proc/{pidof(server)}/fd/{memfd} sul lato client, senza mai chiudere il fd sul lato server overmind. L'unica comunicazione necessaria è una tupla (pid, fd). Funziona perfettamente nel nostro caso.
Le parole sopra si riducono a queste righe:
class SharedMemory:
@classmethod
def create(cls, shift):
# Chiamato sul lato server
libc = ctypes.CDLL(None)
name = _make_filename(shift).encode('utf-8')
fd = libc.memfd_create(name, os.O_RDWR)
os.ftruncate(fd, 1 << shift)
mem_id = (os.getpid(), fd)
return cls(fd=fd, mem_id=mem_id)
@classmethod
def rebuild(cls, mem_id):
# Chiamato sul lato client
pid, fd = mem_id
local_fd = os.open(f'/proc/{pid}/fd/{fd}', os.O_RDWR)
return cls(fd=local_fd, mem_id=mem_id)
def get_buffer(self):
# Chiamato su entrambi i lati
self._mmap = mmap.mmap(self._fd, size)
self._buf = memoryview(self._mmap)
return self._bufIntegrazione con il pickling
Come abbiamo discusso prima, dobbiamo modificare il processo di pickling di UntypedStorage. Simile a quanto implementato in torch.multiprocessing.reductions, definiamo le nostre funzioni di riduzione personalizzate per pickle:
# Hoarder e borrower sono un wrapper per SharedMemory sopra, contengono
# cose noiose come l'arena di memoria, ecc.
def _reduce_storage(storage):
# Chiamato dal server
device = storage.device
storage = storage.cpu()
# Memorizza il contenuto nella memoria condivisa
# Il `frag` contiene tutte le informazioni necessarie per localizzare il contenuto.
frag = hoarder.put(storage)
return (_rebuild_storage_on_client, (frag, device))
def _rebuild_storage_on_client(frag, device):
# Chiamato dal client
mv = borrower.borrow(frag) # Ottieni una memoryview dalla memoria condivisa
storage = _make_untyped_storage(mv) # Zero-copy!
if device.type == 'cuda':
return storage.cuda(device.index)
return storage
class OvermindPickler(dill.Pickler):
...
OvermindPickler.register(torch.storage.UntypedStorage, _reduce_storage)Ora, semplici OvermindPickler.dumps e OvermindPickler.loads utilizzeranno la memoria condivisa per velocizzare. Puoi smettere di leggere qui se sei già stufo. Il resto sono dettagli.
I dettagli del diavolo
Perché non la condivisione di tensori interna di PyTorch?
Per 'condivisione di tensori interna', intendo torch.multiprocessing.reductions.
- Ad alto livello, il metodo di PyTorch è progettato per 'passare il tensore al sottoprocesso', sembra lo stesso ma con una differenza sottile.
- PyTorch utilizza la memoria condivisa POSIX per condividere la memoria, soggetta al limite menzionato in precedenza.
- Per ogni tensore (o
UntypedStorage), PyTorch alloca un oggetto shm POSIX dedicato, anche se contiene solo 4 byte. Ogni oggetto consuma un fd. - PyTorch dealloca la memoria condivisa POSIX una volta che vengono unpickled, rendendola inadatta alle nostre esigenze. Abbiamo bisogno di deserializzare lo stesso flusso di pickle più volte.
- Ci sono molte logiche di condivisione relative a CUDA, che sono puro rumore e problemi per il nostro caso d'uso.
Perché dici che 'i dati del tensore vengono copiati più volte'?
Per un tipico torch.load su disco:
- Il file
torch.savesu disco viene letto in memoria. - Ottieni i dati effettivi di
torch.UntypedStoragecomebytestramite l'estrazione di un file Zip (il comandotorch.savegenera un file zip). - Il codice C++ copierà i dati nella propria memoria gestita nel costruttore di
torch.UntypedStorage.
Per un semplice pickle.dumps e successivamente pickle.loads:
- Lo stream pickle generato incorpora internamente un altro stream pickle,
pickle.loadscopierà lo stream interno in un nuovobytes. - I dati di
torch.UntypedStoragesono incorporati nello stream pickle interno, un'altra copia avviene durante la costruzione ditorch.UntypedStorage. - Il codice C++ copierà i dati nella propria memoria gestita nel costruttore di
torch.UntypedStorage.
diffusers ha un modulo dinamico
I repository dei modelli possono includere file Python che vengono importati a runtime in uno spazio dei nomi diffusers_modules. Il client non li ha in sys.path, interrompendo l'unpickling. Fortunatamente, diffusers scriverà questi file Python dinamici su disco, quindi possiamo semplicemente importare il modulo e risolvere la questione.
def diffusers_dyn_module_workaround():
from diffusers.utils.constants import HF_MODULES_CACHE
modpath = Path(HF_MODULES_CACHE) / "diffusers_modules/__init__.py"
spec = importlib.util.spec_from_file_location("diffusers_modules", modpath)
sys.modules["diffusers_modules"] = importlib.util.module_from_spec(spec)Supporto per bitsandbytes
La cosa più fastidiosa del supporto a bitsandbytes è che il processo di quantizzazione avviene su una GPU. Una volta inizializzati CUDA e torch nel server overmind, non c'è un modo semplice per disinizializzarlo, il che può causare problemi per carichi di lavoro reali (principalmente meno VRAM utilizzabile). Pertanto, abbiamo modificato il nostro server per generare un sottoprocesso, caricarlo in memoria condivisa e terminarlo. Questo migliora la stabilità del server overmind.
I parametri quantizzati sono sottoclassi speciali fornite da bitsandbytes. Non sono stati progettati con la 'picklabilità' in mente, quindi dobbiamo farlo noi stessi.
def _reduce_bnb_param(p):
dev = p._prev_device
assert p.quant_state
return (_rebuild_bnb_param, (type(p), p.data, p.quant_state.as_dict(packed=True), dev))
def _rebuild_bnb_param(typ, data, qs_dict, dev):
return typ.from_prequantized(data, qs_dict, device=dev)
def bitsandbytes_quirks():
try:
import bitsandbytes
except ImportError as e:
return
ForkingPickler.register(bitsandbytes.nn.modules.Params4bit, _reduce_bnb_param)
ForkingPickler.register(bitsandbytes.nn.modules.Int8Params, _reduce_bnb_param)I modelli quantizzati tramite bitsandbytes vengono forniti con hook e monkey-patch che non si possono picklare, dobbiamo rimuoverli:
from accelerate.hooks import remove_hook_from_module
remove_hook_from_module(model, True)
model.__dict__.pop('to', None) # Rimuovi i monkeypatch di avviso
model.__dict__.pop('cuda', None)Abbiamo anche riscontrato problemi in cui le funzioni sono annidate all'interno di altre funzioni (anziché essere a livello superiore), il che le rende non picklabili. Abbiamo provato a trovare una soluzione, ma senza successo. Abbiamo dovuto cambiare il nostro pickle da quello fornito dalla libreria standard a dill per picklare questo. dill è molto più potente, ma è un'implementazione puramente Python, che è molto più lenta della versione della libreria standard. Fortunatamente, questo costo verrà pagato solo una volta quando carichiamo il modello per la prima volta (influisce solo sul pickling, non sull'unpickling).
Supporto per stable-fast
stable-fast genera risultati torch.compile, che non possono essere picklati. Ma con torch.jit.save, potremmo salvare i risultati come un file zip. Questo sembra inefficiente, ma è meglio di niente.
Solo con torch.jit.save non è sufficiente per picklare i risultati di stable-fast. stable-fast utilizza un processo di 'flatten' per rendere il modulo Torch tracciabile. Quando incontra qualcosa che non riconosce (ad esempio, la classe di un dataclass), non lo serializzerà, ma manterrà solo un riferimento alla classe effettiva. Abbiamo patchato la logica pertinente per memorizzare effettivamente una classe picklata all'interno dello stream 'flatten'.
def stable_fast_quirks():
...
# pickle dataclass type instead of just put it into a container (which will not survive after torch.jit.save)
def flatten_dataclass(obj):
from sfast.utils.flat_tensors import flatten_bytes, flatten_dict
import dataclasses
d = dict((field.name, getattr(obj, field.name))
for field in dataclasses.fields(obj))
import pickle
pickled = pickle.dumps(obj.__class__)
return flatten_bytes(pickled) + flatten_dict(d)
def unflatten_dataclass(tensors, start):
from sfast.utils.flat_tensors import unflatten_bytes, unflatten_dict
import pickle
pickled, start = unflatten_bytes(tensors, start)
clz = pickle.loads(pickled)
content, start = unflatten_dict(tensors, start)
return clz(**content), start
sfast.utils.flat_tensors.flatten_dataclass = flatten_dataclass
sfast.utils.flat_tensors.unflatten_dataclass = unflatten_dataclassCi sono altri due trucchi qui:
- Ricomprimiamo il file ZIP con
ZIP_STORED, in modo da non dover decomprimere il file ZIP per ogni caricamento successivo. - L'interfaccia
torch.jit.loadcomporta anche il problema della copia della memoria, quindi abbiamo scritto un semplice wrapper per caricarlo tramite il protocollo buffer di Python, proprio comeUntypedStorage.
void initOvermindHelpers(py::module m) {
// ...
m.def("import_ir_module_from_buffer_0copy",
[](std::shared_ptr<torch::jit::CompilationUnit> cu, py::buffer buffer) {
auto info = buffer.request();
imemstream in((char*)info.ptr, info.size); // Nessuna copia!
return import_ir_module(std::move(cu), in, ...);
}
);
}Il pattern vae=vae
Il nostro codice ha qualcosa di simile, tenta di caricare un modello con un modello precedentemente caricato come suo argomento:
import overmind.api
overmind.api.monkey_patch_all()
import torch
from diffusers.models import AutoencoderKL
from diffusers import (
ControlNetModel,
StableDiffusionControlNetPipeline,
)
vae = AutoencoderKL.from_pretrained(
"lemon2431/ChineseInkComicStrip_v10",
subfolder="vae",
torch_dtype=torch.float16,
)
controlnet_depth = ControlNetModel.from_pretrained(
"lllyasviel/control_v11f1p_sd15_depth",
torch_dtype=torch.float16,
variant="fp16",
)
controlnet_edge = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_softedge",
torch_dtype=torch.float16,
variant="fp16",
)
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
"lemon2431/ChineseInkComicStrip_v10",
vae=vae, # Qui!
controlnet=[controlnet_edge, controlnet_depth], # e Qui!
torch_dtype=torch.float16,
safety_checker=None,
)
pipeline.to('cuda')Come abbiamo menzionato in precedenza, gli argomenti della funzione sono assunti essere oggetti semplici, facilmente serializzabili, ma questo pattern infrange tale assunzione. Per gestire questo, abbiamo aggiunto una logica speciale: ogni risultato memorizzato nella cache ottiene un ID allegato. Se quell'oggetto viene utilizzato come argomento in un'altra chiamata, il client lo sostituisce con il suo ID, e il server può quindi recuperare l'oggetto effettivo basandosi sull'ID.
Il modello pipeline risultante conterrà un riferimento a vae. Per semplicità, lo serializziamo direttamente qui. Tuttavia, quando si sposta l'effettivo UntypedStorage nella memoria condivisa, deduplichiamo qualsiasi dato ripetuto.
Avremmo potuto usare il meccanismo persistent_id di pickle, ma non ho provato questa strada. È un po' un peccato.
Benchmarking
E ora per la parte che tutti amano vedere.
Usiamo lo script del pattern VAE della sezione precedente per fare il nostro test.
| Test | vae | depth | edge | pipeline | to('cuda') | Totale |
|---|---|---|---|---|---|---|
| senza, 1° | 1.18 | 0.98 | 1.41 | 1.65 | 0.91 | 6.16 |
| senza, 2° | 1.15 | 0.96 | 0.97 | 1.65 | 0.89 | 5.66 |
| senza, 3° | 1.15 | 0.96 | 0.98 | 1.61 | 0.91 | 5.65 |
| s/n, 4° | 1.42 | 1.10 | 1.11 | 1.72 | 0.88 | 6.27 |
| s/n, 5° | 1.28 | 1.08 | 1.10 | 1.72 | 0.92 | 6.13 |
| c/n, 1° | 5.44 | 5.17 | 5.41 | 7.29 | 0.86 | 24.20 |
| c/n, 2° | 0.00 | 0.01 | 0.01 | 0.20 | 0.87 | 1.12 |
| c/n, 3° | 0.01 | 0.01 | 0.01 | 0.21 | 0.86 | 1.12 |
| c/n, 4° | 0.01 | 0.01 | 0.01 | 0.20 | 0.90 | 1.15 |
| c/n, 5° | 0.01 | 0.01 | 0.01 | 0.21 | 0.86 | 1.13 |
Come puoi vedere, il caricamento iniziale con overmind richiede 24,2 secondi, che è significativamente più lungo rispetto al caricamento senza di esso. Tuttavia, nei caricamenti successivi, solo il costo di .to('cuda') è ancora presente.
Sommando le dimensioni di tutti i file modello serializzati, si stima che l'intera pipeline utilizzi circa 5808 megabyte di memoria. Un rapido benchmark dà un risultato simile.
In [1]: t = torch.ones((5808, 1024, 1024), dtype=torch.uint8)
In [2]: %time a = t.cuda()
CPU times: user 976 ms, sys: 874 μs, total: 977 ms
Wall time: 976 ms
In [3]: %timeit a = t.cuda()
1.01 s ± 56.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)Testato su Intel i9-11900K + GeForce RTX 4090.
Effetti Collaterali Inaspettati (Positivi!)
La nostra motivazione principale per costruire overmind era abilitare un rapido cambio dei pesi del modello durante l'inferenza. Mentre ha servito il suo scopo, abbiamo scoperto diversi vantaggi aggiuntivi lungo il percorso.
Distribuiamo più istanze della nostra applicazione, una per ogni GPU. Pertanto, ci saranno 8 processi per nodo. Dopo aver distribuito overmind, l'uso della memoria di sistema è stato ridotto drasticamente. Non stavamo soffrendo di carenza di memoria di sistema, ma se lo fossimo stati, questo sarebbe stato un grande vantaggio.
Successivamente, abbiamo scoperto che è stato un grande impulso per i nostri sviluppatori di algoritmi e pipeline. Per ogni ciclo di modifica-verifica, potremmo risparmiare dai 10 ai 20 secondi di tempo di caricamento, questo potrebbe sommarsi a un numero enorme. Ancora più importante, i secondi risparmiati potrebbero mantenere gli sviluppatori nel flusso.
Github
Lo stiamo rendendo open-source su Github, saremo felici se potrà essere d'aiuto.


