TL;DR Le chargement des modèles ML est lent, même avec le cache de pages Linux préchauffé. Nous avons donc créé une bibliothèque pour le rendre rapide. Il y a quelques détails techniques intéressants que nous voulons partager, alors nous avons écrit ce blog. La bibliothèque a également eu un impact inattendu, discuté à la fin.
Raisonnement
Tout a commencé il y a 2 ans, lorsque nous avons lancé notre première tentative de mode de génération lowpoly. Le mode lowpoly ne s'est pas bien passé, il produit de mauvais résultats du point de vue actuel, mais nous avons payé cher pour cela -- un GPU dédié ne traite que quelques tâches par jour. Il a des poids ajustés, suffisamment grands pour évincer tous les autres poids de modèle de la VRAM. Pire encore, nous avons peut-être 3 modèles de ce type (je ne me souviens plus du nombre exact), ils constituaient une partie significative de notre infrastructure d'inférence, avec un ratio d'efficacité assez impitoyable. Et non, nous ne pouvons pas charger naïvement les modèles juste à temps, cela coûte 30s, plus que le temps de traitement réel.
Nous n'avions pas d'ingénieurs de pipeline dédiés à l'époque, nos développeurs d'algorithmes ont fait de leur mieux pour contourner ce problème. Quelques jours plus tard, notre base de code était jonchée de this.to('cpu') et that.to('cuda'). Cette approche fonctionne pendant un certain temps, mais perturbe le flux de nos développeurs d'algo de temps en temps. Et si les choses pouvaient se passer automatiquement ? C'est Python, les choses se passent automatiquement en Python.
Comment définissez-vous 'automatiquement' ?
Passons au rôle d'un développeur d'algorithmes. Les choses sont assez claires : je ne veux pas me soucier des performances d'exécution en dehors de mon algorithme principal à moins d'y être absolument obligé. Je préférerais ne rien savoir sur le swap de modèle.
Bien sûr, nous ne pouvons pas atteindre cela, mais nous pouvons essayer de minimiser l'intrusion que nous devons introduire dans le code de l'algorithme. Cela me rappelle le monkey-patching de la bibliothèque gevent, qui modifie (principalement) la bibliothèque socket, la remplaçant par gevent.socket qui peut basculer vers d'autres greenlets lorsque l'IO bloquerait, un peu comme une goroutine (en fait, gevent est plus ancien que Golang !).
Puisque nous n'utilisions que les bibliothèques HuggingFace (transformers, diffusers) pour charger les modèles à l'époque, l'objectif est devenu clair : Nous n'introduisons qu'un appel de monkey-patch, et le reste du code devrait rester inchangé, XXXPipeline.from_pretrained(...) devrait être beaucoup plus rapide.
Quelques Faits, Décisions Évidentes et Hypothèses
Overmind est une bibliothèque de mise en cache, elle met en cache les résultats des appels de chargement de modèles en mémoire système et les reconstruit rapidement par la suite.
Nous passons outre la discussion sur la façon dont le monkey-patching est implémenté, ce n'est pas un détail très intéressant. Tout ce que nous devons savoir, c'est qu'il redirige tous les appels XXXPipeline.from_pretrained(...) vers overmind.api.load(XXXPipeline.from_pretrained, ...).
Nous utilisons pickle pour sérialiser notre résultat de cache car... nous n'avons pas le choix, et torch.save utilise lui-même pickle, il serait étrange de ne pas l'utiliser.
Nous utilisons une architecture client/serveur car nous ne voulons pas invalider notre cache lorsque le processus se termine. Il y a de nombreux appels de sous-processus qui pourraient en bénéficier.
Nous supposons que les paramètres XXXPipeline.from_pretrained sont des choses simples et hachables (str et similaires) et d'autres modèles chargés par overmind (expliqué plus tard).
Le nom overmind est emprunté à Starcraft, comme vous l'avez peut-être deviné.
Reconstruisez-le rapidement !
Nous ne pouvons pas naïvement enregistrer le résultat de pickle.loads en mémoire et en rester là. Après tout, dans un scénario préchauffé, le cache de pages Linux a fait son travail en mettant en cache les modèles sur disque et nous pouvons toujours voir un temps de chargement mesuré en dizaines de secondes.
L'inefficacité vient de la copie de mémoire. En Python, même créer des millions d'objets ne coûterait pas plus de quelques centaines de ms. Cependant, pour une copie de mémoire de 10 Go, cela coûterait une demi-seconde. Nous devons éviter la copie de mémoire autant que possible.
Heureusement, la plupart des gros blocs de mémoire sont des tenseurs Torch, nous pouvons les adresser en toute sécurité et ignorer le reste.
En fait, j'ai acquis la connaissance de la structure interne d'un tenseur Torch dans le code de réduction en recherchant le mécanisme de partage de tenseurs :
# Copié de torch.multiprocessing.reductions, la plupart du code est supprimé
def reduce_tensor(tensor):
...
storage = tensor._typed_storage()
...
metadata = (
tensor.storage_offset(),
tensor.size(),
tensor.stride(),
tensor.requires_grad,
)
return (rebuild_tensor, (type(tensor), storage, metadata))Assez simple : un tenseur est son type, ses métadonnées et son stockage sous-jacent. Ici, storage est de type TypedStorage, mais en réalité TypedStorage n'est qu'un simple wrapper pour UntypedStorage. UntypedStorage est la classe qui contient réellement toutes les données du tenseur.
Notre tâche devient plus spécifique maintenant : comment éviter de copier UntypedStorage ? Pouvons-nous gérer nous-mêmes la mémoire de ces tenseurs et construire des UntypedStorage en pointant vers la mémoire que nous gérons ?
La réponse est oui !
En parcourant le code C++ où UntypedStorage est construit, nous pouvons facilement trouver un extrait de code comme celui-ci :
// Copié de torch/csrc/Storage.cpp
static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
// ...omission de code non pertinent...
auto new_storage_impl = make_storage_impl(
c10::StorageImpl::use_byte_size_t(),
slicelength,
at::DataPtr(
static_cast<void*>(data + start),
old_storage_impl,
[](void* s) {
c10::raw::intrusive_ptr::decref(static_cast<at::StorageImpl*>(s));
},
old_storage_impl->device()),
old_storage_impl->allocator(),
/* resizable */ false,
device_opt);
PyObject* _ret =
THPStorage_NewWithStorage(Py_TYPE(self), std::move(new_storage_impl));
return _ret;
}Non seulement nous pouvons utiliser un pointeur, mais la classe at::DataPtr peut également gérer la destruction, simplifiant ainsi la gestion de la durée de vie.
Du côté Python, un pointeur vers une région mémoire est représenté par un objet memoryview, ces objets supportent le protocole de buffer. Nous pouvons obtenir un objet memoryview à partir de nombreuses choses, bytes et mmap sont les 2 principales choses qui le supportent, et ce sont également celles qui nous intéressent.
Enfin, nous savons ce que nous devons faire : créer une fonction qui accepte un objet memoryview et le transforme en un UntypedStorage sans copie. Avec la capacité de reconstruire UntypedStorage à partir de memoryview, les données réelles du tenseur n'ont pas besoin d'être dans le flux pickle, réduisant considérablement la taille des données que nous devons copier.
void initOvermindHelpers(py::module m) {
// ...
m.def("_make_untyped_storage", [](py::buffer b) {
auto info = new py::buffer_info(b.request());
return pybind11::reinterpret_steal<py::object>(THPStorage_NewWithStorage(
THPStorageClass,
c10::make_intrusive<at::StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
info->size,
at::DataPtr(
info->ptr, info,
[](void* ptr) {
py::gil_scoped_acquire gil;
auto b = static_cast<py::buffer_info*>(ptr);
delete b;
},
at::DeviceType::CPU
),
/*allocator=*/nullptr,
/*resizable=*/false,
)
));
});
}C'est le bloc de construction principal de overmind.
Partage des tenseurs !
Remarque : Il existe déjà un mécanisme de partage de tenseurs dans PyTorch, mais il ne correspond pas à nos besoins. Plus d'informations à ce sujet plus tard.
Tout d'abord, partage de la mémoire entre client et serveur
Lorsque nous voyons 'partager' et 'mémoire' ensemble, nous avons tous une envie d'utiliser shmget et ses amis. Il est "conçu" pour être utilisé comme un mécanisme de partage de mémoire, n'est-ce pas ? Mais il a 2 défauts majeurs :
- La mémoire partagée POSIX est une ressource rare, ce que vous pouvez utiliser est déterminé par la façon dont l'administrateur système configure le système. Un exemple extrême mais omniprésent est les conteneurs Docker, par défaut vous n'avez que 64MiB de mémoire partagée POSIX utilisable.
- Le shm POSIX vit plus longtemps que votre processus, vous devez donc gérer cela vous-même. Si le processus de gestion est tué de force, ou ne le gère pas avec soin, l'objet shm pourrait rester sur le système indéfiniment.
Si vous regardez attentivement, Linux est plein d'appels système intéressants. memfd_create est celui qui nous intéresse : Il vous donne un fd qui représente une allocation de mémoire anonyme. Vous pouvez effectuer toutes sortes d'opérations de fichier dessus : lire, écrire, et bien sûr, mmap. Si nous pouvons partager le fd, nous pouvons partager la mémoire.
Partager un fd a une manière 'standard' mais ésotérique de le faire : sendmsg avec SCM_RIGHTS. Nous pouvons tirer parti des bibliothèques pour nous aider à cacher les détails intimidants du processus sendmsg, mais nous devons toujours faire notre coordination entre les processus serveur et client. Nous avons décidé d'utiliser un hack ici : Il suffit d'ouvrir /proc/{pidof(server)}/fd/{memfd} côté client, tout en ne fermant jamais le fd côté serveur overmind. La seule communication nécessaire est un tuple (pid, fd). Cela fonctionne parfaitement dans notre cas.
Les mots ci-dessus se résument à ces lignes :
class SharedMemory:
@classmethod
def create(cls, shift):
# Appelé côté serveur
libc = ctypes.CDLL(None)
name = _make_filename(shift).encode('utf-8')
fd = libc.memfd_create(name, os.O_RDWR)
os.ftruncate(fd, 1 << shift)
mem_id = (os.getpid(), fd)
return cls(fd=fd, mem_id=mem_id)
@classmethod
def rebuild(cls, mem_id):
# Appelé côté client
pid, fd = mem_id
local_fd = os.open(f'/proc/{pid}/fd/{fd}', os.O_RDWR)
return cls(fd=local_fd, mem_id=mem_id)
def get_buffer(self):
# Appelé des deux côtés
self._mmap = mmap.mmap(self._fd, size)
self._buf = memoryview(self._mmap)
return self._bufIntégration avec le pickling
Comme nous l'avons discuté précédemment, nous devons modifier le processus de pickling de UntypedStorage. Similaire à ce qui a été implémenté dans torch.multiprocessing.reductions, nous définissons nos fonctions de réduction personnalisées pour pickle :
# Hoarder et borrower sont des wrappers pour SharedMemory ci-dessus, contenant
# des éléments ennuyeux comme l'arène de mémoire, etc.
def _reduce_storage(storage):
# Appelé par le serveur
device = storage.device
storage = storage.cpu()
# Stocker le contenu dans la mémoire partagée
# Le `frag` contient toutes les informations nécessaires pour localiser le contenu.
frag = hoarder.put(storage)
return (_rebuild_storage_on_client, (frag, device))
def _rebuild_storage_on_client(frag, device):
# Appelé par le client
mv = borrower.borrow(frag) # Obtenir une vue mémoire de la mémoire partagée
storage = _make_untyped_storage(mv) # Zéro-copie !
si le type de device est 'cuda':
return storage.cuda(device.index)
return storage
class OvermindPickler(dill.Pickler):
...
OvermindPickler.register(torch.storage.UntypedStorage, _reduce_storage)Maintenant, des simples OvermindPickler.dumps et OvermindPickler.loads utiliseront la mémoire partagée pour accélérer. Vous pouvez arrêter de lire ici si vous en avez déjà assez. Le reste sont des détails.
Les détails du diable
Pourquoi ne pas utiliser le partage de tenseurs interne de PyTorch ?
Pour le 'partage de tenseurs interne', je veux dire torch.multiprocessing.reductions.
- À un niveau élevé, la méthode de PyTorch est conçue pour 'passer un tenseur à un sous-processus', ce qui semble similaire mais avec une différence subtile.
- PyTorch utilise le shm POSIX pour partager la mémoire, soumis à la limite mentionnée précédemment.
- Pour chaque tenseur (ou
UntypedStorage), PyTorch alloue un objet shm POSIX dédié, même s'il ne contient que 4 octets. Chaque objet consomme un fd. - PyTorch désalloue le shm POSIX une fois qu'ils sont dé-sérialisés, ce qui le rend inadapté à nos besoins. Nous devons désérialiser le même flux pickle plusieurs fois.
- Il y a beaucoup de logique de partage liée à CUDA, qui sont du bruit pur et des problèmes pour notre cas d'utilisation.
Pourquoi dites-vous que 'les données des tenseurs sont copiées plusieurs fois' ?
Pour un torch.load typique sur disque :
- Le fichier
torch.savesur disque est lu en mémoire. - Obtenez les données réelles de
torch.UntypedStoragesous forme debytespar extraction de fichier Zip (torch.savegénère un fichier zip). - Le code C++ copiera les données dans sa propre mémoire gérée dans le constructeur de
torch.UntypedStorage.
Pour un pickle.dumps naïf et plus tard pickle.loads :
- Le flux pickle généré intègre en interne un autre flux pickle,
pickle.loadscopiera le flux interne dans un nouveaubytes. - Les données de
torch.UntypedStoragesont intégrées dans le flux pickle interne, une autre copie se produit lors de la construction detorch.UntypedStorage. - Le code C++ copiera les données dans sa propre mémoire gérée dans le constructeur de
torch.UntypedStorage.
diffusers ont un module dynamique
Les dépôts de modèles peuvent inclure des fichiers Python qui sont importés à l'exécution dans un espace de noms diffusers_modules. Le client ne les a pas dans sys.path, ce qui casse le déballage. Heureusement, diffusers écrira ces fichiers Python dynamiques sur le disque, donc nous pouvons simplement importer le module et passer à autre chose.
def diffusers_dyn_module_workaround():
from diffusers.utils.constants import HF_MODULES_CACHE
modpath = Path(HF_MODULES_CACHE) / "diffusers_modules/__init__.py"
spec = importlib.util.spec_from_file_location("diffusers_modules", modpath)
sys.modules["diffusers_modules"] = importlib.util.module_from_spec(spec)Support pour bitsandbytes
La chose la plus ennuyeuse à propos du support de bitsandbytes est que le processus de quantification se produit sur un GPU. Une fois que nous avons initialisé CUDA et torch dans le serveur overmind, il n'y a pas de moyen facile de le désinitialiser, ce qui peut causer des problèmes pour les charges de travail réelles (principalement moins de VRAM utilisable). Par conséquent, nous avons modifié notre serveur pour engendrer un sous-processus, le charger en mémoire partagée, et terminer. Cela améliore la stabilité du serveur overmind.
Les paramètres quantifiés sont des sous-classes spéciales fournies par bitsandbytes. Ils n'ont pas été conçus pour être 'picklables', donc nous devons le faire nous-mêmes.
def _reduce_bnb_param(p):
dev = p._prev_device
assert p.quant_state
return (_rebuild_bnb_param, (type(p), p.data, p.quant_state.as_dict(packed=True), dev))
def _rebuild_bnb_param(typ, data, qs_dict, dev):
return typ.from_prequantized(data, qs_dict, device=dev)
def bitsandbytes_quirks():
try:
import bitsandbytes
except ImportError as e:
return
ForkingPickler.register(bitsandbytes.nn.modules.Params4bit, _reduce_bnb_param)
ForkingPickler.register(bitsandbytes.nn.modules.Int8Params, _reduce_bnb_param)Les modèles quantifiés via bitsandbytes viennent avec des hooks et des monkey-patches qui ne se picklent pas, nous devons les retirer :
from accelerate.hooks import remove_hook_from_module
remove_hook_from_module(model, True)
model.__dict__.pop('to', None) # Supprimer les monkeypatches d'avertissement
model.__dict__.pop('cuda', None)Nous avons également rencontré des problèmes où les fonctions sont imbriquées dans d'autres fonctions (plutôt qu'au niveau supérieur), ce qui les rend non picklables. Nous avons essayé de contourner cela, mais sans succès. Nous avons dû passer de notre pickle fourni par la bibliothèque standard à dill pour le pickler. dill est beaucoup plus puissant, mais c'est une implémentation pure Python, ce qui est beaucoup plus lent que la version de la bibliothèque standard. Heureusement, ce coût ne sera payé qu'une seule fois lors du chargement initial du modèle (n'affecte que le pickling, pas le unpickling).
Support pour stable-fast
stable-fast génère des résultats torch.compile, qui ne peuvent pas être picklés. Mais avec torch.jit.save, nous pourrions sauvegarder les résultats sous forme de fichier zip. Cela semble inefficace, mais c'est mieux que rien.
Avec seulement torch.jit.save, il n'est pas suffisant de pickler les résultats de stable-fast. stable-fast utilise un processus de 'flatten' pour rendre le module Torch traçable. Lorsqu'il rencontre quelque chose qu'il ne reconnaît pas (par exemple, la classe dataclass), il ne la sérialisera pas, mais ne gardera qu'une référence à la classe réelle. Nous avons patché la logique pertinente pour réellement stocker une classe picklée dans le flux 'flatten'.
def stable_fast_quirks():
...
# pickle dataclass type instead of just put it into a container (which will not survive after torch.jit.save)
def flatten_dataclass(obj):
from sfast.utils.flat_tensors import flatten_bytes, flatten_dict
import dataclasses
d = dict((field.name, getattr(obj, field.name))
for field in dataclasses.fields(obj))
import pickle
pickled = pickle.dumps(obj.__class__)
return flatten_bytes(pickled) + flatten_dict(d)
def unflatten_dataclass(tensors, start):
from sfast.utils.flat_tensors import unflatten_bytes, unflatten_dict
import pickle
pickled, start = unflatten_bytes(tensors, start)
clz = pickle.loads(pickled)
content, start = unflatten_dict(tensors, start)
return clz(**content), start
sfast.utils.flat_tensors.flatten_dataclass = flatten_dataclass
sfast.utils.flat_tensors.unflatten_dataclass = unflatten_dataclassIl y a deux autres astuces ici :
- Nous reconditionnons le fichier ZIP avec
ZIP_STORED, donc nous n'avons pas besoin de décompresser le fichier ZIP pour chaque chargement ultérieur. - L'interface
torch.jit.loadentraîne également un problème de copie en mémoire, nous avons donc écrit un simple wrapper pour le charger via le protocole de buffer Python, tout commeUntypedStorage.
void initOvermindHelpers(py::module m) {
// ...
m.def("import_ir_module_from_buffer_0copy",
[](std::shared_ptr<torch::jit::CompilationUnit> cu, py::buffer buffer) {
auto info = buffer.request();
imemstream in((char*)info.ptr, info.size); // Pas de copie !
return import_ir_module(std::move(cu), in, ...);
}
);
}Le motif vae=vae
Notre base de code a quelque chose comme cela, elle tente de charger un modèle avec un modèle précédemment chargé comme argument :
import overmind.api
overmind.api.monkey_patch_all()
import torch
from diffusers.models import AutoencoderKL
from diffusers import (
ControlNetModel,
StableDiffusionControlNetPipeline,
)
vae = AutoencoderKL.from_pretrained(
"lemon2431/ChineseInkComicStrip_v10",
subfolder="vae",
torch_dtype=torch.float16,
)
controlnet_depth = ControlNetModel.from_pretrained(
"lllyasviel/control_v11f1p_sd15_depth",
torch_dtype=torch.float16,
variant="fp16",
)
controlnet_edge = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_softedge",
torch_dtype=torch.float16,
variant="fp16",
)
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
"lemon2431/ChineseInkComicStrip_v10",
vae=vae, # Ici !
controlnet=[controlnet_edge, controlnet_depth], # et Ici !
torch_dtype=torch.float16,
safety_checker=None,
)
pipeline.to('cuda')Comme nous l'avons mentionné précédemment, les arguments de fonction sont supposés être des objets simples, facilement sérialisables, mais ce motif brise cette hypothèse. Pour gérer cela, nous avons ajouté une logique spéciale : chaque résultat mis en cache reçoit un ID attaché. Si cet objet est utilisé comme argument dans un autre appel, le client le remplace par son ID, et le serveur peut alors récupérer l'objet réel basé sur l'ID.
Le modèle pipeline résultant contiendra une référence à vae. Pour simplifier, nous le sérialisons directement ici. Cependant, lors du déplacement de l'UntypedStorage réel vers la mémoire partagée, nous dédupliquons toutes les données répétées.
Nous aurions pu utiliser le mécanisme persistent_id de pickle, mais je n'ai pas essayé cette voie. C'est un peu dommage.
Benchmarking
Et maintenant pour la partie que tout le monde aime voir.
Nous utilisons le script de motif VAE de la dernière section pour faire notre test.
| Test | vae | depth | edge | pipeline | to('cuda') | Total |
|---|---|---|---|---|---|---|
| w/o, 1st | 1.18 | 0.98 | 1.41 | 1.65 | 0.91 | 6.16 |
| w/o, 2nd | 1.15 | 0.96 | 0.97 | 1.65 | 0.89 | 5.66 |
| w/o, 3rd | 1.15 | 0.96 | 0.98 | 1.61 | 0.91 | 5.65 |
| s/s, 4ème | 1.42 | 1.10 | 1.11 | 1.72 | 0.88 | 6.27 |
| s/s, 5ème | 1.28 | 1.08 | 1.10 | 1.72 | 0.92 | 6.13 |
| avec, 1ère | 5.44 | 5.17 | 5.41 | 7.29 | 0.86 | 24.20 |
| avec, 2ème | 0.00 | 0.01 | 0.01 | 0.20 | 0.87 | 1.12 |
| avec, 3ème | 0.01 | 0.01 | 0.01 | 0.21 | 0.86 | 1.12 |
| avec, 4ème | 0.01 | 0.01 | 0.01 | 0.20 | 0.90 | 1.15 |
| avec, 5ème | 0.01 | 0.01 | 0.01 | 0.21 | 0.86 | 1.13 |
Comme vous pouvez le voir, le chargement initial avec overmind prend 24,2 secondes, ce qui est nettement plus long par rapport au chargement sans. Cependant, lors des chargements suivants, seul le coût de .to('cuda') est encore présent.
En additionnant les tailles de tous les fichiers de modèle sérialisés, l'ensemble du pipeline est estimé à utiliser environ 5808 mégaoctets de mémoire. Un test rapide donne un résultat similaire.
In [1]: t = torch.ones((5808, 1024, 1024), dtype=torch.uint8)
In [2]: %time a = t.cuda()
CPU times: user 976 ms, sys: 874 μs, total: 977 ms
Wall time: 976 ms
In [3]: %timeit a = t.cuda()
1.01 s ± 56.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)Testé sur Intel i9-11900K + GeForce RTX 4090.
Effets Secondaires Inattendus (Positifs !)
Notre motivation principale pour construire overmind était de permettre un changement rapide des poids du modèle lors de l'inférence. Bien qu'il ait servi son objectif, nous avons découvert plusieurs avantages supplémentaires en cours de route.
Nous déployons plusieurs instances de notre application, une pour chaque GPU. Ainsi, il y aura 8 processus par nœud. Après avoir déployé overmind, l'utilisation de la mémoire système a été réduite de manière spectaculaire. Nous ne souffrions pas d'un manque de mémoire système, mais si cela avait été le cas, cela aurait été un grand avantage.
Plus tard, nous avons constaté que cela constituait un grand coup de pouce pour nos développeurs d'algorithmes et de pipelines. Pour chaque boucle de modification-vérification, nous pouvions économiser 10 à 20 secondes de temps de chargement, ce qui pourrait s'accumuler en un nombre énorme. Plus important encore, les secondes économisées pouvaient maintenir les développeurs dans le flux.
Github
Nous le rendons open-source sur Github, nous serons heureux si cela vous aide.


