mirror of https://github.com/skygpu/skynet.git
Begin adding TUI
parent
93ee65087f
commit
b3dc7c1074
|
@ -61,6 +61,7 @@ cuda = [
|
|||
"basicsr>=1.4.2,<2",
|
||||
"realesrgan>=0.3.0,<0.4",
|
||||
"sentencepiece>=0.2.0",
|
||||
"urwid>=2.6.16",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
import warnings
|
||||
|
||||
import trio
|
||||
|
||||
|
@ -6,11 +7,31 @@ from hypercorn.config import Config
|
|||
from hypercorn.trio import serve
|
||||
from quart_trio import QuartTrio as Quart
|
||||
|
||||
from skynet.dgpu.tui import WorkerMonitor
|
||||
from skynet.dgpu.compute import ModelMngr
|
||||
from skynet.dgpu.daemon import WorkerDaemon
|
||||
from skynet.dgpu.network import NetConnector
|
||||
|
||||
|
||||
def setup_logging_for_tui(level):
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(level)
|
||||
|
||||
fh = logging.FileHandler('dgpu.log')
|
||||
fh.setLevel(level)
|
||||
|
||||
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||
fh.setFormatter(formatter)
|
||||
|
||||
logger.addHandler(fh)
|
||||
|
||||
for handler in logger.handlers:
|
||||
if isinstance(handler, logging.StreamHandler):
|
||||
logger.removeHandler(handler)
|
||||
|
||||
|
||||
async def open_dgpu_node(config: dict) -> None:
|
||||
'''
|
||||
Open a top level "GPU mgmt daemon", keep the
|
||||
|
@ -18,13 +39,17 @@ async def open_dgpu_node(config: dict) -> None:
|
|||
and *maybe* serve a `hypercorn` web API.
|
||||
|
||||
'''
|
||||
|
||||
# suppress logs from httpx (logs url + status after every query)
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
|
||||
conn = NetConnector(config)
|
||||
mm = ModelMngr(config)
|
||||
daemon = WorkerDaemon(mm, conn, config)
|
||||
tui = None
|
||||
if config['tui']:
|
||||
setup_logging_for_tui(logging.INFO)
|
||||
tui = WorkerMonitor()
|
||||
|
||||
conn = NetConnector(config, tui=tui)
|
||||
mm = ModelMngr(config, tui=tui)
|
||||
daemon = WorkerDaemon(mm, conn, config, tui=tui)
|
||||
|
||||
api: Quart|None = None
|
||||
if 'api_bind' in config:
|
||||
|
@ -35,6 +60,8 @@ async def open_dgpu_node(config: dict) -> None:
|
|||
tn: trio.Nursery
|
||||
async with trio.open_nursery() as tn:
|
||||
tn.start_soon(daemon.snap_updater_task)
|
||||
if tui:
|
||||
tn.start_soon(tui.run)
|
||||
|
||||
# TODO, consider a more explicit `as hypercorn_serve`
|
||||
# to clarify?
|
||||
|
@ -42,5 +69,9 @@ async def open_dgpu_node(config: dict) -> None:
|
|||
logging.info(f'serving api @ {config["api_bind"]}')
|
||||
tn.start_soon(serve, api, api_conf)
|
||||
|
||||
try:
|
||||
# block until cancelled
|
||||
await daemon.serve_forever()
|
||||
|
||||
except *urwid.ExitMainLoop in ex_group:
|
||||
...
|
||||
|
|
|
@ -11,6 +11,7 @@ from hashlib import sha256
|
|||
import trio
|
||||
import torch
|
||||
|
||||
from skynet.dgpu.tui import WorkerMonitor
|
||||
from skynet.dgpu.errors import (
|
||||
DGPUComputeError,
|
||||
DGPUInferenceCancelled,
|
||||
|
@ -72,7 +73,8 @@ class ModelMngr:
|
|||
checking load state, and unloading when no-longer-needed/finished.
|
||||
|
||||
'''
|
||||
def __init__(self, config: dict):
|
||||
def __init__(self, config: dict, tui: WorkerMonitor | None = None):
|
||||
self._tui = tui
|
||||
self.cache_dir = None
|
||||
if 'hf_home' in config:
|
||||
self.cache_dir = config['hf_home']
|
||||
|
@ -80,8 +82,6 @@ class ModelMngr:
|
|||
self._model_name: str = ''
|
||||
self._model_mode: str = ''
|
||||
|
||||
# self.load_model(DEFAULT_INITAL_MODEL, 'txt2img')
|
||||
|
||||
def log_debug_info(self):
|
||||
logging.debug('memory summary:')
|
||||
logging.debug('\n' + torch.cuda.memory_summary())
|
||||
|
@ -110,6 +110,7 @@ class ModelMngr:
|
|||
) -> None:
|
||||
logging.info(f'loading model {name}...')
|
||||
self.unload_model()
|
||||
|
||||
self._model = pipeline_for(
|
||||
name, mode, cache_dir=self.cache_dir)
|
||||
self._model_mode = mode
|
||||
|
@ -124,11 +125,19 @@ class ModelMngr:
|
|||
params: dict,
|
||||
inputs: list[bytes] = []
|
||||
):
|
||||
def maybe_cancel_work(step, *args, **kwargs):
|
||||
total_steps = params['step']
|
||||
def inference_step_wakeup(*args, **kwargs):
|
||||
'''This is a callback function that gets invoked every inference step,
|
||||
we need to raise an exception here if we need to cancel work
|
||||
'''
|
||||
if self._should_cancel:
|
||||
step = args[0]
|
||||
# compat with callback_on_step_end
|
||||
if not isinstance(step, int):
|
||||
step = args[1]
|
||||
|
||||
if self._tui:
|
||||
self._tui.set_progress(step, done=total_steps)
|
||||
|
||||
should_raise = trio.from_thread.run(self._should_cancel, request_id)
|
||||
if should_raise:
|
||||
logging.warning(f'CANCELLING work at step {step}')
|
||||
|
@ -136,7 +145,10 @@ class ModelMngr:
|
|||
|
||||
return {}
|
||||
|
||||
maybe_cancel_work(0)
|
||||
if self._tui:
|
||||
self._tui.set_status(f'Request #{request_id}')
|
||||
|
||||
inference_step_wakeup(0)
|
||||
|
||||
output_type = 'png'
|
||||
if 'output_type' in params:
|
||||
|
@ -157,10 +169,10 @@ class ModelMngr:
|
|||
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
||||
|
||||
if 'flux' in name.lower():
|
||||
extra_params['callback_on_step_end'] = maybe_cancel_work
|
||||
extra_params['callback_on_step_end'] = inference_step_wakeup
|
||||
|
||||
else:
|
||||
extra_params['callback'] = maybe_cancel_work
|
||||
extra_params['callback'] = inference_step_wakeup
|
||||
extra_params['callback_steps'] = 1
|
||||
|
||||
output = self._model(
|
||||
|
@ -213,4 +225,7 @@ class ModelMngr:
|
|||
finally:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if self._tui:
|
||||
self._tui.set_status('')
|
||||
|
||||
return output_hash, output
|
||||
|
|
|
@ -17,6 +17,7 @@ from skynet.constants import (
|
|||
from skynet.dgpu.errors import (
|
||||
DGPUComputeError,
|
||||
)
|
||||
from skynet.dgpu.tui import WorkerMonitor
|
||||
from skynet.dgpu.compute import ModelMngr
|
||||
from skynet.dgpu.network import NetConnector
|
||||
|
||||
|
@ -41,10 +42,12 @@ class WorkerDaemon:
|
|||
self,
|
||||
mm: ModelMngr,
|
||||
conn: NetConnector,
|
||||
config: dict
|
||||
config: dict,
|
||||
tui: WorkerMonitor | None = None
|
||||
):
|
||||
self.mm: ModelMngr = mm
|
||||
self.conn: NetConnector = conn
|
||||
self._tui = tui
|
||||
self.auto_withdraw = (
|
||||
config['auto_withdraw']
|
||||
if 'auto_withdraw' in config else False
|
||||
|
@ -150,6 +153,12 @@ class WorkerDaemon:
|
|||
|
||||
return app
|
||||
|
||||
async def _update_balance(self):
|
||||
if self._tui:
|
||||
# update balance
|
||||
balance = await self.conn.get_worker_balance()
|
||||
self._tui.set_header_text(new_balance=f'balance: {balance}')
|
||||
|
||||
# TODO? this func is kinda big and maybe is better at module
|
||||
# level to reduce indentation?
|
||||
# -[ ] just pass `daemon: WorkerDaemon` vs. `self`
|
||||
|
@ -238,6 +247,8 @@ class WorkerDaemon:
|
|||
request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
|
||||
logging.info(f'calculated request hash: {request_hash}')
|
||||
|
||||
total_step = body['params']['step']
|
||||
|
||||
# TODO: validate request
|
||||
|
||||
resp = await self.conn.begin_work(rid)
|
||||
|
@ -246,6 +257,9 @@ class WorkerDaemon:
|
|||
|
||||
else:
|
||||
try:
|
||||
if self._tui:
|
||||
self._tui.set_progress(0, done=total_step)
|
||||
|
||||
output_type = 'png'
|
||||
if 'output_type' in body['params']:
|
||||
output_type = body['params']['output_type']
|
||||
|
@ -269,6 +283,9 @@ class WorkerDaemon:
|
|||
f'Unsupported backend {self.backend}'
|
||||
)
|
||||
|
||||
if self._tui:
|
||||
self._tui.set_progress(total_step)
|
||||
|
||||
self._last_generation_ts: str = datetime.now().isoformat()
|
||||
self._last_benchmark: list[float] = self._benchmark
|
||||
self._benchmark: list[float] = []
|
||||
|
@ -277,6 +294,9 @@ class WorkerDaemon:
|
|||
|
||||
await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
|
||||
|
||||
await self._update_balance()
|
||||
|
||||
|
||||
except BaseException as err:
|
||||
if 'network cancel' not in str(err):
|
||||
logging.exception('Failed to serve model request !?\n')
|
||||
|
@ -294,6 +314,7 @@ class WorkerDaemon:
|
|||
# -[ ] keeps tasks-as-funcs style prominent
|
||||
# -[ ] avoids so much indentation due to methods
|
||||
async def serve_forever(self):
|
||||
await self._update_balance()
|
||||
try:
|
||||
while True:
|
||||
if self.auto_withdraw:
|
||||
|
|
|
@ -13,6 +13,7 @@ import outcome
|
|||
from PIL import Image
|
||||
from leap.cleos import CLEOS
|
||||
from leap.protocol import Asset
|
||||
from skynet.dgpu.tui import WorkerMonitor
|
||||
from skynet.constants import (
|
||||
DEFAULT_IPFS_DOMAIN,
|
||||
GPU_CONTRACT_ABI,
|
||||
|
@ -57,7 +58,7 @@ class NetConnector:
|
|||
- CLEOS client
|
||||
|
||||
'''
|
||||
def __init__(self, config: dict):
|
||||
def __init__(self, config: dict, tui: WorkerMonitor | None = None):
|
||||
# TODO, why these extra instance vars for an (unsynced)
|
||||
# copy of the `config` state?
|
||||
self.account = config['account']
|
||||
|
@ -81,6 +82,10 @@ class NetConnector:
|
|||
self.ipfs_domain = config['ipfs_domain']
|
||||
|
||||
self._wip_requests = {}
|
||||
self._tui = tui
|
||||
if self._tui:
|
||||
self._tui.set_header_text(new_worker_name=self.account)
|
||||
|
||||
|
||||
# blockchain helpers
|
||||
|
||||
|
@ -163,6 +168,9 @@ class NetConnector:
|
|||
n.start_soon(
|
||||
_run_and_save, snap['requests'], req['id'], self.get_status_by_request_id, req['id'])
|
||||
|
||||
if self._tui:
|
||||
self._tui.network_update(snap)
|
||||
|
||||
return snap
|
||||
|
||||
async def begin_work(self, request_id: int):
|
||||
|
|
|
@ -0,0 +1,248 @@
|
|||
import urwid
|
||||
import trio
|
||||
import json
|
||||
|
||||
|
||||
class WorkerMonitor:
|
||||
def __init__(self):
|
||||
self.requests = []
|
||||
self.header_info = {}
|
||||
|
||||
self.palette = [
|
||||
('headerbar', 'white', 'dark blue'),
|
||||
('request_row', 'white', 'dark gray'),
|
||||
('worker_row', 'light gray', 'black'),
|
||||
('progress_normal', 'black', 'light gray'),
|
||||
('progress_complete', 'black', 'dark green'),
|
||||
('body', 'white', 'black'),
|
||||
]
|
||||
|
||||
# --- Top bar (header) ---
|
||||
worker_name = self.header_info.get('left', "unknown")
|
||||
balance = self.header_info.get('right', "balance: unknown")
|
||||
|
||||
self.worker_name_widget = urwid.Text(worker_name)
|
||||
self.balance_widget = urwid.Text(balance, align='right')
|
||||
|
||||
header = urwid.Columns([self.worker_name_widget, self.balance_widget])
|
||||
header_attr = urwid.AttrMap(header, 'headerbar')
|
||||
|
||||
# --- Body (List of requests) ---
|
||||
self.body_listbox = self._create_listbox_body(self.requests)
|
||||
|
||||
# --- Bottom bar (progress) ---
|
||||
self.status_text = urwid.Text("Request: none", align='left')
|
||||
self.progress_bar = urwid.ProgressBar(
|
||||
'progress_normal',
|
||||
'progress_complete',
|
||||
current=0,
|
||||
done=100
|
||||
)
|
||||
|
||||
footer_cols = urwid.Columns([
|
||||
('fixed', 20, self.status_text),
|
||||
self.progress_bar,
|
||||
])
|
||||
|
||||
# Build the main frame
|
||||
frame = urwid.Frame(
|
||||
self.body_listbox,
|
||||
header=header_attr,
|
||||
footer=footer_cols
|
||||
)
|
||||
|
||||
# Set up the main loop with Trio
|
||||
self.event_loop = urwid.TrioEventLoop()
|
||||
self.main_loop = urwid.MainLoop(
|
||||
frame,
|
||||
palette=self.palette,
|
||||
event_loop=self.event_loop,
|
||||
unhandled_input=self._exit_on_q
|
||||
)
|
||||
|
||||
def _create_listbox_body(self, requests):
|
||||
"""
|
||||
Build a ListBox (vertical list) of requests & workers using SimpleFocusListWalker.
|
||||
"""
|
||||
widgets = self._build_request_widgets(requests)
|
||||
walker = urwid.SimpleFocusListWalker(widgets)
|
||||
return urwid.ListBox(walker)
|
||||
|
||||
def _build_request_widgets(self, requests):
|
||||
"""
|
||||
Build a list of Urwid widgets (one row per request + per worker).
|
||||
"""
|
||||
row_widgets = []
|
||||
|
||||
for req in requests:
|
||||
# Build a columns widget for the request row
|
||||
columns = urwid.Columns([
|
||||
('fixed', 5, urwid.Text(f"#{req['id']}")), # e.g. "#12"
|
||||
('weight', 3, urwid.Text(req['model'])),
|
||||
('weight', 3, urwid.Text(req['prompt'])),
|
||||
('fixed', 13, urwid.Text(req['user'])),
|
||||
('fixed', 13, urwid.Text(req['reward'])),
|
||||
], dividechars=1)
|
||||
|
||||
# Wrap the columns with an attribute map for coloring
|
||||
request_row = urwid.AttrMap(columns, 'request_row')
|
||||
row_widgets.append(request_row)
|
||||
|
||||
# Then add each worker in its own line below
|
||||
for w in req["workers"]:
|
||||
worker_line = urwid.Text(f" {w}")
|
||||
worker_row = urwid.AttrMap(worker_line, 'worker_row')
|
||||
row_widgets.append(worker_row)
|
||||
|
||||
# Optional blank line after each request
|
||||
row_widgets.append(urwid.Text(""))
|
||||
|
||||
return row_widgets
|
||||
|
||||
def _exit_on_q(self, key):
|
||||
"""Exit the TUI on 'q' or 'Q'."""
|
||||
if key in ('q', 'Q'):
|
||||
raise urwid.ExitMainLoop()
|
||||
|
||||
async def run(self):
|
||||
"""
|
||||
Run the TUI in an async context (Trio).
|
||||
This method blocks until the user quits (pressing q/Q).
|
||||
"""
|
||||
with self.main_loop.start():
|
||||
await self.event_loop.run_async()
|
||||
|
||||
raise urwid.ExitMainLoop()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Public Methods to Update Various Parts of the UI
|
||||
# -------------------------------------------------------------------------
|
||||
def set_status(self, status: str):
|
||||
self.status_text.set_text(status)
|
||||
|
||||
def set_progress(self, current, done=None):
|
||||
"""
|
||||
Update the bottom progress bar.
|
||||
- `current`: new current progress value (int).
|
||||
- `done`: max progress value (int). If None, we don’t change it.
|
||||
"""
|
||||
if done is not None:
|
||||
self.progress_bar.done = done
|
||||
|
||||
self.progress_bar.current = current
|
||||
|
||||
pct = 0
|
||||
if self.progress_bar.done != 0:
|
||||
pct = int((self.progress_bar.current / self.progress_bar.done) * 100)
|
||||
|
||||
def update_requests(self, new_requests):
|
||||
"""
|
||||
Replace the data in the existing ListBox with new request widgets.
|
||||
"""
|
||||
new_widgets = self._build_request_widgets(new_requests)
|
||||
self.body_listbox.body[:] = new_widgets # replace content of the list walker
|
||||
|
||||
def set_header_text(self, new_worker_name=None, new_balance=None):
|
||||
"""
|
||||
Update the text in the header bar for worker name and/or balance.
|
||||
"""
|
||||
if new_worker_name is not None:
|
||||
self.worker_name_widget.set_text(new_worker_name)
|
||||
if new_balance is not None:
|
||||
self.balance_widget.set_text(new_balance)
|
||||
|
||||
def network_update(self, snapshot: dict):
|
||||
queue = [
|
||||
{
|
||||
**r,
|
||||
**(json.loads(r['body'])['params']),
|
||||
'workers': [s['worker'] for s in snapshot['requests'][r['id']]]
|
||||
}
|
||||
for r in snapshot['queue']
|
||||
]
|
||||
self.update_requests(queue)
|
||||
|
||||
|
||||
# # -----------------------------------------------------------------------------
|
||||
# # Example usage
|
||||
# # -----------------------------------------------------------------------------
|
||||
#
|
||||
# async def main():
|
||||
# # Example data
|
||||
# example_requests = [
|
||||
# {
|
||||
# "id": 12,
|
||||
# "model": "black-forest-labs/FLUX.1-schnell",
|
||||
# "prompt": "Generate an answer about quantum entanglement.",
|
||||
# "user": "alice123",
|
||||
# "reward": "20.0000 GPU",
|
||||
# "workers": ["workerA", "workerB"],
|
||||
# },
|
||||
# {
|
||||
# "id": 5,
|
||||
# "model": "some-other-model/v2.0",
|
||||
# "prompt": "A story about dragons.",
|
||||
# "user": "bobthebuilder",
|
||||
# "reward": "15.0000 GPU",
|
||||
# "workers": ["workerX"],
|
||||
# },
|
||||
# {
|
||||
# "id": 99,
|
||||
# "model": "cool-model/turbo",
|
||||
# "prompt": "Classify sentiment in these tweets.",
|
||||
# "user": "charlie",
|
||||
# "reward": "25.5000 GPU",
|
||||
# "workers": ["workerOne", "workerTwo", "workerThree"],
|
||||
# },
|
||||
# ]
|
||||
#
|
||||
# ui = WorkerMonitor()
|
||||
#
|
||||
# async def progress_task():
|
||||
# # Fill from 0% to 100%
|
||||
# for pct in range(101):
|
||||
# ui.set_progress(
|
||||
# current=pct,
|
||||
# status_str=f"Request #1234 ({pct}%)"
|
||||
# )
|
||||
# await trio.sleep(0.05)
|
||||
# # Reset to 0
|
||||
# ui.set_progress(
|
||||
# current=0,
|
||||
# status_str="Starting again..."
|
||||
# )
|
||||
#
|
||||
# async def update_data_task():
|
||||
# await trio.sleep(3) # Wait a bit, then update requests
|
||||
# new_data = [{
|
||||
# "id": 101,
|
||||
# "model": "new-model/v1.0",
|
||||
# "prompt": "Say hi to the world.",
|
||||
# "user": "eve",
|
||||
# "reward": "50.0000 GPU",
|
||||
# "workers": ["workerFresh", "workerPower"],
|
||||
# }]
|
||||
# ui.update_requests(new_data)
|
||||
# ui.set_header_text(new_worker_name="NewNodeName",
|
||||
# new_balance="balance: 12345.6789 GPU")
|
||||
#
|
||||
# try:
|
||||
# async with trio.open_nursery() as nursery:
|
||||
# # Run the TUI
|
||||
# nursery.start_soon(ui.run_teadown_on_exit, nursery)
|
||||
#
|
||||
# ui.update_requests(example_requests)
|
||||
# ui.set_header_text(
|
||||
# new_worker_name="worker1.scd",
|
||||
# new_balance="balance: 12345.6789 GPU"
|
||||
# )
|
||||
# # Start background tasks
|
||||
# nursery.start_soon(progress_task)
|
||||
# nursery.start_soon(update_data_task)
|
||||
#
|
||||
# except *KeyboardInterrupt as ex_group:
|
||||
# ...
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# trio.run(main)
|
|
@ -7,8 +7,10 @@ import logging
|
|||
import importlib
|
||||
|
||||
from typing import Optional
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
import diffusers
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
|
@ -74,12 +76,27 @@ def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image:
|
|||
return crop_image(convert_from_bytes_to_img(raw), max_w, max_h)
|
||||
|
||||
|
||||
class DummyPB:
|
||||
def update(self):
|
||||
...
|
||||
|
||||
@torch.compiler.disable
|
||||
@contextmanager
|
||||
def dummy_progress_bar(*args, **kwargs):
|
||||
yield DummyPB()
|
||||
|
||||
|
||||
def monkey_patch_pipeline_disable_progress_bar(pipe):
|
||||
pipe.progress_bar = dummy_progress_bar
|
||||
|
||||
|
||||
def pipeline_for(
|
||||
model: str,
|
||||
mode: str,
|
||||
mem_fraction: float = 1.0,
|
||||
cache_dir: str | None = None
|
||||
) -> DiffusionPipeline:
|
||||
diffusers.utils.logging.disable_progress_bar()
|
||||
|
||||
logging.info(f'pipeline_for {model} {mode}')
|
||||
assert torch.cuda.is_available()
|
||||
|
@ -105,7 +122,9 @@ def pipeline_for(
|
|||
normalized_shortname = shortname.replace('-', '_')
|
||||
custom_pipeline = importlib.import_module(f'skynet.dgpu.pipes.{normalized_shortname}')
|
||||
assert custom_pipeline.__model['name'] == model
|
||||
return custom_pipeline.pipeline_for(model, mode, mem_fraction=mem_fraction, cache_dir=cache_dir)
|
||||
pipe = custom_pipeline.pipeline_for(model, mode, mem_fraction=mem_fraction, cache_dir=cache_dir)
|
||||
monkey_patch_pipeline_disable_progress_bar(pipe)
|
||||
return pipe
|
||||
|
||||
except ImportError:
|
||||
# TODO, uhh why not warn/error log this?
|
||||
|
@ -121,7 +140,6 @@ def pipeline_for(
|
|||
logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
|
||||
|
||||
params = {
|
||||
'safety_checker': None,
|
||||
'torch_dtype': torch.float16,
|
||||
'cache_dir': cache_dir,
|
||||
'variant': 'fp16',
|
||||
|
@ -130,6 +148,7 @@ def pipeline_for(
|
|||
match shortname:
|
||||
case 'stable':
|
||||
params['revision'] = 'fp16'
|
||||
params['safety_checker'] = None
|
||||
|
||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||
|
||||
|
@ -167,6 +186,8 @@ def pipeline_for(
|
|||
|
||||
pipe = pipe.to('cuda')
|
||||
|
||||
monkey_patch_pipeline_disable_progress_bar(pipe)
|
||||
|
||||
return pipe
|
||||
|
||||
|
||||
|
|
24
uv.lock
24
uv.lock
|
@ -2261,6 +2261,7 @@ cuda = [
|
|||
{ name = "torchvision" },
|
||||
{ name = "transformers" },
|
||||
{ name = "triton" },
|
||||
{ name = "urwid" },
|
||||
{ name = "xformers" },
|
||||
]
|
||||
dev = [
|
||||
|
@ -2312,6 +2313,7 @@ cuda = [
|
|||
{ name = "torchvision", specifier = "==0.20.1+cu121", index = "https://download.pytorch.org/whl/cu121" },
|
||||
{ name = "transformers", specifier = "==4.48.0" },
|
||||
{ name = "triton", specifier = "==3.1.0", index = "https://download.pytorch.org/whl/cu121" },
|
||||
{ name = "urwid", specifier = ">=2.6.16" },
|
||||
{ name = "xformers", specifier = ">=0.0.29,<0.0.30" },
|
||||
]
|
||||
dev = [
|
||||
|
@ -2626,6 +2628,28 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "urwid"
|
||||
version = "2.6.16"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "wcwidth" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/98/21/ad23c9e961b2d36d57c63686a6f86768dd945d406323fb58c84f09478530/urwid-2.6.16.tar.gz", hash = "sha256:93ad239939e44c385e64aa00027878b9e5c486d59e855ec8ab5b1e1adcdb32a2", size = 848179 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/54/cb/271a4f5a1bf4208dbdc96d85b9eae744cf4e5e11ac73eda76dc98c8fd2d7/urwid-2.6.16-py3-none-any.whl", hash = "sha256:de14896c6df9eb759ed1fd93e0384a5279e51e0dde8f621e4083f7a8368c0797", size = 297196 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wcwidth"
|
||||
version = "0.2.13"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6c/63/53559446a878410fc5a5974feb13d31d78d752eb18aeba59c7fef1af7598/wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5", size = 101301 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "websocket-client"
|
||||
version = "1.8.0"
|
||||
|
|
Loading…
Reference in New Issue