TL;DR La carga de modelos de ML es lenta, incluso con la caché de páginas de Linux precalentada. Así que construimos una biblioteca para hacerla rápida. Hay algunos detalles técnicos interesantes que queremos compartir, por lo que escribimos este blog. La biblioteca también tuvo un impacto inesperado, discutido al final.
Justificación
Todo comienza hace 2 años, cuando lanzamos nuestro primer intento de modo de generación lowpoly. El modo lowpoly no fue bien, emite resultados pobres desde la perspectiva actual, pero pagamos mucho por ello: una GPU dedicada solo procesa tareas de un solo dígito por día. Tiene pesos ajustados, lo suficientemente grandes como para expulsar todos los demás pesos del modelo de la VRAM. Peor aún, tenemos tal vez 3 modelos de este tipo (no puedo recordar el número exacto), constituyeron una parte significativa de nuestra infraestructura de inferencia, haciendo una proporción de eficiencia bastante implacable. Y no, no podemos cargar los modelos ingenuamente justo a tiempo, cuesta 30s, más que el tiempo de procesamiento real.
No teníamos ingenieros de pipeline dedicados entonces, nuestros desarrolladores de algoritmos hicieron lo mejor para encontrar una solución. Días después, nuestra base de código estaba llena de this.to('cpu') y that.to('cuda'). Este enfoque funciona por un tiempo, pero rompe el flujo de nuestros desarrolladores de algoritmos de vez en cuando. ¿Y si las cosas pudieran suceder automáticamente? Es Python, las cosas suceden automáticamente en Python.
¿Cómo defines 'automáticamente'?
Sumerjámonos en el papel de un desarrollador de algoritmos. Las cosas están bastante claras: no quiero preocuparme por el rendimiento en tiempo de ejecución fuera de mi algoritmo principal a menos que absolutamente tenga que hacerlo. Preferiría no saber nada sobre el intercambio de modelos dentro y fuera.
Por supuesto, no podemos lograr eso, pero podemos intentar minimizar la intrusión que tenemos que introducir en el código del algoritmo. Esto me recuerda al monkey-patching de la biblioteca gevent, que parchea (principalmente) la biblioteca socket, reemplazándola con gevent.socket que puede cambiar a otros greenlets cuando el IO bloquearía, muy parecido a una goroutine (¡en realidad gevent es más antiguo que Golang!).
Dado que solo estábamos usando las bibliotecas de HuggingFace (transformers, diffusers) para cargar modelos en ese momento, el objetivo se volvió claro: Solo introducimos una llamada de monkey-patch, y el resto del código debería permanecer sin cambios, XXXPipeline.from_pretrained(...) debería ser mucho más rápido.
Algunos Hechos, Decisiones Obvias y Suposiciones
Overmind es una biblioteca de caché, almacena en caché los resultados de las llamadas de carga de modelos en la memoria del sistema y luego los reconstruye rápidamente.
Omitimos la discusión sobre cómo se implementa el monkey-patching, ese es un detalle no tan interesante. Todo lo que necesitamos saber es que redirige todas las llamadas XXXPipeline.from_pretrained(...) a overmind.api.load(XXXPipeline.from_pretrained, ...).
Usamos pickle para serializar nuestro resultado de caché ya que... no tenemos otra opción, y torch.save en sí mismo usa pickle, sería raro no usarlo.
Usamos una arquitectura cliente/servidor ya que no queremos invalidar nuestra caché cuando el proceso termina. Hay muchas llamadas a subprocesos que podrían beneficiarse de ello.
Asumimos que los parámetros de XXXPipeline.from_pretrained son cosas simples que se pueden hash (str y cosas similares) y otros modelos cargados por overmind (explicado más adelante).
El nombre overmind se toma prestado de Starcraft, como habrás adivinado.
¡Reconstrúyelo rápido!
No podemos guardar ingenuamente el resultado de pickle.loads en memoria y darlo por terminado. Después de todo, en un escenario precalentado, la caché de páginas de Linux hizo su trabajo almacenando en caché los modelos en disco y aún podemos ver un tiempo de carga medido en decenas de segundos.
La ineficiencia proviene de la copia de memoria. En Python, incluso crear millones de objetos no costaría más de varios cientos de ms. Sin embargo, para una copia de memoria de 10GiB, costaría medio segundo. Debemos evitar la copia de memoria tanto como sea posible.
Afortunadamente, la mayoría de los grandes bloques de memoria son tensores de Torch, podemos abordarlos de manera segura y ignorar el resto.
De hecho, obtuve el conocimiento de la estructura interna de un tensor de Torch en el código de reducción mientras investigaba el mecanismo de compartición de tensores:
# Copiado de torch.multiprocessing.reductions, la mayor parte del código se ha eliminado
def reduce_tensor(tensor):
...
storage = tensor._typed_storage()
...
metadata = (
tensor.storage_offset(),
tensor.size(),
tensor.stride(),
tensor.requires_grad,
)
return (rebuild_tensor, (type(tensor), storage, metadata))Bastante simple: un tensor es su tipo, sus metadatos y su almacenamiento subyacente. Aquí storage es de tipo TypedStorage, pero en realidad TypedStorage es solo un contenedor simple para UntypedStorage. UntypedStorage es la clase que realmente contiene todos los datos del tensor.
Nuestra tarea se vuelve más específica ahora: ¿Cómo evitamos copiar UntypedStorage? ¿Podemos gestionar esta memoria de tensor por nosotros mismos y construir UntypedStorages apuntando a la memoria que gestionamos?
¡La respuesta es sí!
Echando un vistazo al código C++ donde se construye UntypedStorage, podemos encontrar fácilmente un fragmento de código como este:
// Copiado de torch/csrc/Storage.cpp
static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
// ...omitiendo código no relacionado...
auto new_storage_impl = make_storage_impl(
c10::StorageImpl::use_byte_size_t(),
slicelength,
at::DataPtr(
static_cast<void*>(data + start),
old_storage_impl,
[](void* s) {
c10::raw::intrusive_ptr::decref(static_cast<at::StorageImpl*>(s));
},
old_storage_impl->device()),
old_storage_impl->allocator(),
/* resizable */ false,
device_opt);
PyObject* _ret =
THPStorage_NewWithStorage(Py_TYPE(self), std::move(new_storage_impl));
return _ret;
}No solo podemos usar un puntero, sino que la clase at::DataPtr también puede manejar la destrucción, haciendo que la gestión del ciclo de vida sea mucho más sencilla.
En el lado de Python, un puntero a una región de memoria está representado por un objeto memoryview, estos objetos soportan el protocolo de buffer. Podemos obtener un objeto memoryview de muchas cosas, bytes y mmap son las 2 principales cosas que lo soportan, y son también lo que nos interesa.
Finalmente, sabemos lo que debemos hacer: crear una función que acepte un objeto memoryview y lo convierta en un UntypedStorage sin copiar. Con la capacidad de reconstruir UntypedStorage desde memoryview, los datos reales del tensor no tienen que estar en el flujo de pickle, lo que reduce enormemente el tamaño de los datos que tenemos que copiar.
void initOvermindHelpers(py::module m) {
// ...
m.def("_make_untyped_storage", [](py::buffer b) {
auto info = new py::buffer_info(b.request());
return pybind11::reinterpret_steal<py::object>(THPStorage_NewWithStorage(
THPStorageClass,
c10::make_intrusive<at::StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
info->size,
at::DataPtr(
info->ptr, info,
[](void* ptr) {
py::gil_scoped_acquire gil;
auto b = static_cast<py::buffer_info*>(ptr);
delete b;
},
at::DeviceType::CPU
),
/*allocator=*/nullptr,
/*resizable=*/false,
)
));
});
}Ese es el bloque de construcción principal de overmind.
¡Compartiendo los tensores!
Nota: Ya existe un mecanismo de compartición de tensores en PyTorch, pero no se ajusta a nuestras necesidades. Más sobre esto más adelante.
Primero, compartiendo memoria entre cliente y servidor
Cuando vemos 'compartir' y 'memoria' juntos, todos tenemos la tentación de usar shmget y sus amigos. Está "diseñado" para ser utilizado como un mecanismo de compartición de memoria, ¿verdad? Pero tiene 2 grandes defectos:
- El shm POSIX es un recurso escaso, lo que puedes usar está determinado por cómo el administrador del sistema configura el sistema. Un ejemplo extremo pero ubicuo son los contenedores Docker, por defecto solo tienes 64MiB de shm POSIX utilizable.
- La memoria compartida POSIX vive más tiempo que tu proceso, tienes que hacer tu propia gestión. Si el proceso de gestión es forzado a terminar, o no lo maneja cuidadosamente, el objeto shm podría quedar en el sistema indefinidamente.
Si miras con atención, Linux está lleno de llamadas al sistema interesantes. memfd_create es una que nos interesa: Te da un fd que representa una asignación de memoria anónima. Puedes realizar todo tipo de operaciones de archivo sobre él: leer, escribir y, por supuesto, mmap. Si podemos compartir el fd, podemos compartir la memoria.
Compartir un fd tiene una manera 'estándar' pero arcana de hacerlo: sendmsg con SCM_RIGHTS. Podemos aprovechar bibliotecas para ayudarnos a ocultar los detalles intimidantes del proceso sendmsg, pero aún tenemos que coordinar entre los procesos del servidor y el cliente. Decidimos usar un truco aquí: Simplemente abrir /proc/{pidof(server)}/fd/{memfd} en el lado del cliente, mientras nunca cerramos el fd en el lado del servidor overmind. La única comunicación necesaria es una tupla (pid, fd). Funciona perfectamente en nuestro caso.
Las palabras anteriores se resumen en estas líneas:
class SharedMemory:
@classmethod
def create(cls, shift):
# Llamado en el lado del servidor
libc = ctypes.CDLL(None)
name = _make_filename(shift).encode('utf-8')
fd = libc.memfd_create(name, os.O_RDWR)
os.ftruncate(fd, 1 << shift)
mem_id = (os.getpid(), fd)
return cls(fd=fd, mem_id=mem_id)
@classmethod
def rebuild(cls, mem_id):
# Llamado en el lado del cliente
pid, fd = mem_id
local_fd = os.open(f'/proc/{pid}/fd/{fd}', os.O_RDWR)
return cls(fd=local_fd, mem_id=mem_id)
def get_buffer(self):
# Llamado en ambos lados
self._mmap = mmap.mmap(self._fd, size)
self._buf = memoryview(self._mmap)
return self._bufIntegración con pickling
Como discutimos antes, necesitamos modificar el proceso de pickling de UntypedStorage. Similar a lo que se implementó en torch.multiprocessing.reductions, definimos nuestras funciones de reducción personalizadas para pickle:
# Hoarder y borrower son un envoltorio para SharedMemory arriba, contienen
# cosas aburridas como el área de memoria, etc.
def _reduce_storage(storage):
# Llamado por el servidor
device = storage.device
storage = storage.cpu()
# Almacenar contenido en memoria compartida
# El `frag` contiene la información completa necesaria para localizar el contenido.
frag = hoarder.put(storage)
return (_rebuild_storage_on_client, (frag, device))
def _rebuild_storage_on_client(frag, device):
# Llamado por el cliente
mv = borrower.borrow(frag) # Obtener una vista de memoria desde la memoria compartida
storage = _make_untyped_storage(mv) # ¡Cero copias!
if device.type == 'cuda':
return storage.cuda(device.index)
return storage
class OvermindPickler(dill.Pickler):
...
OvermindPickler.register(torch.storage.UntypedStorage, _reduce_storage)Ahora, simples OvermindPickler.dumps y OvermindPickler.loads utilizarán memoria compartida para acelerar. Puedes dejar de leer aquí si ya estás harto. El resto son detalles.
Los detalles del diablo
¿Por qué no el método de compartir tensores de PyTorch?
Por 'método de compartir tensores', me refiero a torch.multiprocessing.reductions.
- A un nivel alto, el método de PyTorch está diseñado para 'pasar tensores a un subproceso', parece lo mismo pero con una diferencia sutil.
- PyTorch usa shm POSIX para compartir memoria, sujeto al límite mencionado anteriormente.
- Para cada tensor (o
UntypedStorage), PyTorch asigna un objeto shm POSIX dedicado para él, incluso si contiene solo 4 bytes. Cada objeto consume un fd. - PyTorch desasigna el shm POSIX una vez que se desempaquetan, lo que lo hace inadecuado para nuestras necesidades. Necesitamos deserializar el mismo flujo de pickle múltiples veces.
- Hay mucha lógica relacionada con el intercambio de CUDA, que es puro ruido y problema para nuestro caso de uso.
¿Por qué dices que 'los datos del tensor se copian múltiples veces'?
Para un torch.load típico en disco:
- El archivo
torch.saveen disco se lee en memoria. - Obtén los datos reales de
torch.UntypedStoragecomobytesmediante la extracción de archivos Zip (ya quetorch.savegenera un archivo zip). - El código en C++ copiará los datos en su propia memoria gestionada en el constructor de
torch.UntypedStorage.
Para un pickle.dumps ingenuo y luego pickle.loads:
- El flujo de pickle generado internamente incorpora otro flujo de pickle,
pickle.loadscopiará el flujo interno en un nuevobytes. - Los datos de
torch.UntypedStoragese incorporan en el flujo de pickle interno, otra copia ocurre en la construcción detorch.UntypedStorage. - El código en C++ copiará los datos en su propia memoria gestionada en el constructor de
torch.UntypedStorage.
diffusers tienen un módulo dinámico
Los repositorios de modelos pueden incluir archivos Python que se importan en tiempo de ejecución en un espacio de nombres diffusers_modules. El cliente no tiene estos en sys.path, lo que rompe la deserialización. Afortunadamente, diffusers escribirá estos archivos Python dinámicos en el disco, por lo que simplemente podemos importar el módulo y listo.
def diffusers_dyn_module_workaround():
from diffusers.utils.constants import HF_MODULES_CACHE
modpath = Path(HF_MODULES_CACHE) / "diffusers_modules/__init__.py"
spec = importlib.util.spec_from_file_location("diffusers_modules", modpath)
sys.modules["diffusers_modules"] = importlib.util.module_from_spec(spec)Soporte para bitsandbytes
Lo más molesto de soportar bitsandbytes es que el proceso de cuantización ocurre en una GPU. Una vez que inicializamos CUDA y torch en el servidor overmind, no hay una manera fácil de desinicializarlo, lo que puede causar problemas para cargas de trabajo reales (principalmente menos VRAM utilizable). Por lo tanto, modificamos nuestro servidor para generar un subproceso, cargarlo en memoria compartida y terminarlo. Esto mejora la estabilidad del servidor overmind.
Los parámetros cuantizados son subclases especiales proporcionadas por bitsandbytes. No fueron diseñados pensando en la 'capacidad de ser serializados', por lo que tenemos que hacerlo nosotros mismos.
def _reduce_bnb_param(p):
dev = p._prev_device
assert p.quant_state
return (_rebuild_bnb_param, (type(p), p.data, p.quant_state.as_dict(packed=True), dev))
def _rebuild_bnb_param(typ, data, qs_dict, dev):
return typ.from_prequantized(data, qs_dict, device=dev)
def bitsandbytes_quirks():
try:
import bitsandbytes
except ImportError as e:
return
ForkingPickler.register(bitsandbytes.nn.modules.Params4bit, _reduce_bnb_param)
ForkingPickler.register(bitsandbytes.nn.modules.Int8Params, _reduce_bnb_param)Los modelos cuantizados a través de bitsandbytes vienen con hooks y parches que no se pueden serializar, debemos eliminarlos:
from accelerate.hooks import remove_hook_from_module
remove_hook_from_module(model, True)
model.__dict__.pop('to', None) # Eliminar parches de advertencia
model.__dict__.pop('cuda', None)También hemos encontrado problemas donde las funciones están anidadas dentro de otras funciones (en lugar de estar en el nivel superior), lo que las hace no serializables. Intentamos solucionar esto, pero sin suerte. Tuvimos que cambiar nuestro pickle del proporcionado por la biblioteca estándar a dill para serializar esto. dill es mucho más potente, pero es una implementación pura de Python, lo que es mucho más lento que la versión de la biblioteca estándar. Afortunadamente, este costo solo se pagará una vez cuando carguemos el modelo por primera vez (solo afecta la serialización, no la deserialización).
Soporte para stable-fast
stable-fast genera resultados de torch.compile, que no se pueden serializar. Pero con torch.jit.save, podríamos guardar los resultados como un archivo zip. Esto suena ineficiente, pero es mejor que nada.
Solo con torch.jit.save no es suficiente para serializar los resultados de stable-fast. stable-fast utiliza un proceso de 'aplanamiento' para hacer que el módulo Torch sea trazable. Al encontrar algo que no reconoce (por ejemplo, la clase de dataclass), no lo serializará, sino que solo mantendrá una referencia a la clase real. Hemos parcheado la lógica relevante para almacenar realmente una clase serializada dentro del flujo 'aplanado'.
def stable_fast_quirks():
...
# pickle dataclass type instead of just put it into a container (which will not survive after torch.jit.save)
def flatten_dataclass(obj):
from sfast.utils.flat_tensors import flatten_bytes, flatten_dict
import dataclasses
d = dict((field.name, getattr(obj, field.name))
for field in dataclasses.fields(obj))
import pickle
pickled = pickle.dumps(obj.__class__)
return flatten_bytes(pickled) + flatten_dict(d)
def unflatten_dataclass(tensors, start):
from sfast.utils.flat_tensors import unflatten_bytes, unflatten_dict
import pickle
pickled, start = unflatten_bytes(tensors, start)
clz = pickle.loads(pickled)
content, start = unflatten_dict(tensors, start)
return clz(**content), start
sfast.utils.flat_tensors.flatten_dataclass = flatten_dataclass
sfast.utils.flat_tensors.unflatten_dataclass = unflatten_dataclassHay dos trucos más aquí:
- Reempaquetamos el archivo ZIP con
ZIP_STORED, de modo que no tengamos que descomprimir el archivo ZIP para cada carga subsiguiente. - La interfaz
torch.jit.loadtambién incurre en el problema de copia de memoria, por lo que escribimos un simple contenedor para cargarlo a través del protocolo de buffer de Python, al igual queUntypedStorage.
void initOvermindHelpers(py::module m) {
// ...
m.def("import_ir_module_from_buffer_0copy",
[](std::shared_ptr<torch::jit::CompilationUnit> cu, py::buffer buffer) {
auto info = buffer.request();
imemstream in((char*)info.ptr, info.size); // ¡Sin copia!
return import_ir_module(std::move(cu), in, ...);
}
);
}El patrón vae=vae
Nuestra base de código tiene algo como esto, intenta cargar un modelo con un modelo previamente cargado como su argumento:
import overmind.api
overmind.api.monkey_patch_all()
import torch
from diffusers.models import AutoencoderKL
from diffusers import (
ControlNetModel,
StableDiffusionControlNetPipeline,
)
vae = AutoencoderKL.from_pretrained(
"lemon2431/ChineseInkComicStrip_v10",
subfolder="vae",
torch_dtype=torch.float16,
)
controlnet_depth = ControlNetModel.from_pretrained(
"lllyasviel/control_v11f1p_sd15_depth",
torch_dtype=torch.float16,
variant="fp16",
)
controlnet_edge = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_softedge",
torch_dtype=torch.float16,
variant="fp16",
)
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
"lemon2431/ChineseInkComicStrip_v10",
vae=vae, # ¡Aquí!
controlnet=[controlnet_edge, controlnet_depth], # ¡y Aquí!
torch_dtype=torch.float16,
safety_checker=None,
)
pipeline.to('cuda')Como mencionamos anteriormente, se supone que los argumentos de función son objetos simples y fácilmente serializables, pero este patrón rompe esa suposición. Para manejar esto, añadimos una lógica especial: cada resultado en caché obtiene un ID adjunto. Si ese objeto se usa como argumento en otra llamada, el cliente lo reemplaza con su ID, y el servidor puede entonces recuperar el objeto real basado en el ID.
El modelo pipeline resultante contendrá una referencia a vae. Para simplificar, simplemente lo serializamos directamente aquí. Sin embargo, al mover el UntypedStorage real a la memoria compartida, deduplicamos cualquier dato repetido.
Podríamos haber usado el mecanismo persistent_id de pickle, pero no intenté esta ruta. Es una lástima.
Benchmarking
Y ahora la parte que todos aman ver.
Usamos el script del patrón VAE de la última sección para hacer nuestra prueba.
| Prueba | vae | depth | edge | pipeline | to('cuda') | Total |
|---|---|---|---|---|---|---|
| sin, 1ra | 1.18 | 0.98 | 1.41 | 1.65 | 0.91 | 6.16 |
| sin, 2da | 1.15 | 0.96 | 0.97 | 1.65 | 0.89 | 5.66 |
| sin, 3ra | 1.15 | 0.96 | 0.98 | 1.61 | 0.91 | 5.65 |
| w/o, 4th | 1.42 | 1.10 | 1.11 | 1.72 | 0.88 | 6.27 |
| w/o, 5th | 1.28 | 1.08 | 1.10 | 1.72 | 0.92 | 6.13 |
| w/, 1st | 5.44 | 5.17 | 5.41 | 7.29 | 0.86 | 24.20 |
| w/, 2nd | 0.00 | 0.01 | 0.01 | 0.20 | 0.87 | 1.12 |
| w/, 3rd | 0.01 | 0.01 | 0.01 | 0.21 | 0.86 | 1.12 |
| w/, 4th | 0.01 | 0.01 | 0.01 | 0.20 | 0.90 | 1.15 |
| w/, 5th | 0.01 | 0.01 | 0.01 | 0.21 | 0.86 | 1.13 |
Como puedes ver, la carga inicial con overmind toma 24.2 segundos, lo cual es significativamente más largo en comparación con la carga sin él. Sin embargo, en cargas subsecuentes, solo el costo de .to('cuda') sigue presente.
Sumando los tamaños de todos los archivos del modelo serializados, se estima que todo el pipeline usa alrededor de 5808 megabytes de memoria. Un benchmark rápido da un resultado similar.
In [1]: t = torch.ones((5808, 1024, 1024), dtype=torch.uint8)
In [2]: %time a = t.cuda()
CPU times: user 976 ms, sys: 874 μs, total: 977 ms
Wall time: 976 ms
In [3]: %timeit a = t.cuda()
1.01 s ± 56.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)Probado en Intel i9-11900K + GeForce RTX 4090.
Efectos Secundarios Inesperados (¡Positivos!)
Nuestra motivación principal para construir overmind fue permitir el cambio rápido de pesos del modelo durante la inferencia. Aunque cumplió su propósito, descubrimos varias ventajas adicionales en el camino.
Desplegamos múltiples instancias de nuestra aplicación, una para cada GPU. Por lo tanto, habrá 8 procesos por nodo. Después de desplegar overmind, el uso de memoria del sistema se redujo dramáticamente. No estábamos sufriendo de escasez de memoria del sistema, pero si lo hubiéramos estado, esto habría sido una gran ventaja.
Más tarde, encontramos que fue un gran impulso para nuestros desarrolladores de algoritmos y pipelines. Por cada ciclo de modificar-verificar, podríamos ahorrar de 10 a 20 segundos de tiempo de carga, esto podría sumar un gran número. Más importante aún, los segundos ahorrados podrían mantener a los desarrolladores en el flujo.
Github
Lo estamos haciendo de código abierto en Github, estaremos felices si ayuda.


