mirror of https://github.com/skygpu/skynet.git
Merge 868e2489b3
into e88792c9d6
commit
631a2295f9
|
@ -11,3 +11,4 @@ docs
|
|||
ipfs-docker-data
|
||||
ipfs-staging
|
||||
weights
|
||||
*.png
|
||||
|
|
|
@ -29,9 +29,6 @@ poetry shell
|
|||
# test you can run this command
|
||||
skynet --help
|
||||
|
||||
# launch ipfs node
|
||||
skynet run ipfs
|
||||
|
||||
# to launch worker
|
||||
skynet run dgpu
|
||||
|
||||
|
@ -77,9 +74,6 @@ docker pull guilledk/skynet:runtime-cuda
|
|||
# or build it (takes a bit of time)
|
||||
./build_docker.sh
|
||||
|
||||
# launch simple ipfs node
|
||||
./launch_ipfs.sh
|
||||
|
||||
# run worker with all gpus
|
||||
docker run \
|
||||
-it \
|
||||
|
|
|
@ -1,25 +0,0 @@
|
|||
from python:3.11
|
||||
|
||||
env DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
run apt-get update && apt-get install -y \
|
||||
git
|
||||
|
||||
run curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
env PATH "/root/.local/bin:$PATH"
|
||||
|
||||
copy . /skynet
|
||||
|
||||
workdir /skynet
|
||||
|
||||
env POETRY_VIRTUALENVS_PATH /skynet/.venv
|
||||
|
||||
run poetry install
|
||||
|
||||
workdir /root/target
|
||||
|
||||
copy docker/entrypoint.sh /entrypoint.sh
|
||||
entrypoint ["/entrypoint.sh"]
|
||||
|
||||
cmd ["skynet", "--help"]
|
|
@ -1,46 +0,0 @@
|
|||
from nvidia/cuda:11.8.0-devel-ubuntu20.04
|
||||
from python:3.10
|
||||
|
||||
env DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
run apt-get update && apt-get install -y \
|
||||
git \
|
||||
clang \
|
||||
cmake \
|
||||
ffmpeg \
|
||||
libsm6 \
|
||||
libxext6 \
|
||||
ninja-build
|
||||
|
||||
env CC /usr/bin/clang
|
||||
env CXX /usr/bin/clang++
|
||||
|
||||
# install llvm10 as required by llvm-lite
|
||||
run git clone https://github.com/llvm/llvm-project.git -b llvmorg-10.0.1
|
||||
workdir /llvm-project
|
||||
# this adds a commit from 12.0.0 that fixes build on newer compilers
|
||||
run git cherry-pick -n b498303066a63a203d24f739b2d2e0e56dca70d1
|
||||
run cmake -S llvm -B build -G Ninja -DCMAKE_BUILD_TYPE=Release
|
||||
run ninja -C build install # -j8
|
||||
|
||||
run curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
env PATH "/root/.local/bin:$PATH"
|
||||
|
||||
copy . /skynet
|
||||
|
||||
workdir /skynet
|
||||
|
||||
env POETRY_VIRTUALENVS_PATH /skynet/.venv
|
||||
|
||||
run poetry install --with=cuda -v
|
||||
|
||||
workdir /root/target
|
||||
|
||||
env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128
|
||||
env NVIDIA_VISIBLE_DEVICES=all
|
||||
|
||||
copy docker/entrypoint.sh /entrypoint.sh
|
||||
entrypoint ["/entrypoint.sh"]
|
||||
|
||||
cmd ["skynet", "--help"]
|
|
@ -1,46 +0,0 @@
|
|||
from nvidia/cuda:11.8.0-devel-ubuntu20.04
|
||||
from python:3.11
|
||||
|
||||
env DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
run apt-get update && apt-get install -y \
|
||||
git \
|
||||
clang \
|
||||
cmake \
|
||||
ffmpeg \
|
||||
libsm6 \
|
||||
libxext6 \
|
||||
ninja-build
|
||||
|
||||
env CC /usr/bin/clang
|
||||
env CXX /usr/bin/clang++
|
||||
|
||||
# install llvm10 as required by llvm-lite
|
||||
run git clone https://github.com/llvm/llvm-project.git -b llvmorg-10.0.1
|
||||
workdir /llvm-project
|
||||
# this adds a commit from 12.0.0 that fixes build on newer compilers
|
||||
run git cherry-pick -n b498303066a63a203d24f739b2d2e0e56dca70d1
|
||||
run cmake -S llvm -B build -G Ninja -DCMAKE_BUILD_TYPE=Release
|
||||
run ninja -C build install # -j8
|
||||
|
||||
run curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
env PATH "/root/.local/bin:$PATH"
|
||||
|
||||
copy . /skynet
|
||||
|
||||
workdir /skynet
|
||||
|
||||
env POETRY_VIRTUALENVS_PATH /skynet/.venv
|
||||
|
||||
run poetry install --with=cuda -v
|
||||
|
||||
workdir /root/target
|
||||
|
||||
env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:128
|
||||
env NVIDIA_VISIBLE_DEVICES=all
|
||||
|
||||
copy docker/entrypoint.sh /entrypoint.sh
|
||||
entrypoint ["/entrypoint.sh"]
|
||||
|
||||
cmd ["skynet", "--help"]
|
|
@ -0,0 +1,43 @@
|
|||
from nvidia/cuda:12.4.1-devel-ubuntu22.04
|
||||
|
||||
env DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
run apt-get update && apt-get install -y \
|
||||
git \
|
||||
curl \
|
||||
libgl1 \
|
||||
libglib2.0-0 \
|
||||
libglu1-mesa \
|
||||
libgl1-mesa-glx
|
||||
|
||||
env PATH="/opt/uv:$PATH"
|
||||
|
||||
arg USER_ID=1000
|
||||
arg GROUP_ID=1000
|
||||
|
||||
run groupadd -g $GROUP_ID skynet \
|
||||
&& useradd -l -u $USER_ID -g skynet -s /bin/bash skynet \
|
||||
&& mkdir -p /home/skynet \
|
||||
&& chown -R skynet:skynet /home/skynet
|
||||
|
||||
run curl -LsSf https://astral.sh/uv/install.sh | env UV_UNMANAGED_INSTALL="/opt/uv" sh
|
||||
|
||||
run chown -R skynet:skynet /opt/uv
|
||||
|
||||
run ls /opt/uv -lah
|
||||
|
||||
user skynet
|
||||
|
||||
workdir /home/skynet
|
||||
|
||||
run uv venv --python 3.12
|
||||
|
||||
workdir /home/skynet/target
|
||||
|
||||
env PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128
|
||||
env NVIDIA_VISIBLE_DEVICES=all
|
||||
|
||||
copy docker/entrypoint.sh /entrypoint.sh
|
||||
entrypoint ["/entrypoint.sh"]
|
||||
|
||||
cmd ["skynet", "--help"]
|
|
@ -1,25 +0,0 @@
|
|||
from python:3.11
|
||||
|
||||
env DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
run apt-get update && apt-get install -y \
|
||||
git
|
||||
|
||||
run curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
env PATH "/root/.local/bin:$PATH"
|
||||
|
||||
copy . /skynet
|
||||
|
||||
workdir /skynet
|
||||
|
||||
env POETRY_VIRTUALENVS_PATH /skynet/.venv
|
||||
|
||||
run poetry install --with=frontend -v
|
||||
|
||||
workdir /root/target
|
||||
|
||||
copy docker/entrypoint.sh /entrypoint.sh
|
||||
entrypoint ["/entrypoint.sh"]
|
||||
|
||||
cmd ["skynet", "--help"]
|
|
@ -1,20 +1,3 @@
|
|||
|
||||
docker build \
|
||||
-t guilledk/skynet:runtime \
|
||||
-f docker/Dockerfile.runtime .
|
||||
|
||||
docker build \
|
||||
-t guilledk/skynet:runtime-frontend \
|
||||
-f docker/Dockerfile.runtime+frontend .
|
||||
|
||||
docker build \
|
||||
-t guilledk/skynet:runtime-cuda-py311 \
|
||||
-f docker/Dockerfile.runtime+cuda-py311 .
|
||||
|
||||
docker build \
|
||||
-t guilledk/skynet:runtime-cuda \
|
||||
-f docker/Dockerfile.runtime+cuda-py311 .
|
||||
|
||||
docker build \
|
||||
-t guilledk/skynet:runtime-cuda-py310 \
|
||||
-f docker/Dockerfile.runtime+cuda-py310 .
|
||||
-t guilledk/skynet:runtime-cuda-py312 \
|
||||
-f docker/Dockerfile.runtime+cuda-py312 . --progress=plain
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
#!/bin/sh
|
||||
|
||||
export VIRTUAL_ENV='/skynet/.venv'
|
||||
poetry env use $VIRTUAL_ENV/bin/python
|
||||
uv sync
|
||||
|
||||
poetry install
|
||||
|
||||
exec poetry run "$@"
|
||||
exec uv run "$@"
|
||||
|
|
|
@ -1,5 +1 @@
|
|||
docker push guilledk/skynet:runtime
|
||||
docker push guilledk/skynet:runtime-frontend
|
||||
docker push guilledk/skynet:runtime-cuda
|
||||
docker push guilledk/skynet:runtime-cuda-py311
|
||||
docker push guilledk/skynet:runtime-cuda-py310
|
||||
docker push guilledk/skynet:runtime-cuda-py312
|
||||
|
|
|
@ -1,36 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
name='skynet-ipfs'
|
||||
peers=("$@")
|
||||
|
||||
data_dir="$(pwd)/ipfs-docker-data"
|
||||
data_target='/data/ipfs'
|
||||
|
||||
# Create data directory if it doesn't exist
|
||||
mkdir -p "$data_dir"
|
||||
|
||||
# Run the container
|
||||
docker run -d \
|
||||
--name "$name" \
|
||||
-p 8080:8080/tcp \
|
||||
-p 4001:4001/tcp \
|
||||
-p 127.0.0.1:5001:5001/tcp \
|
||||
--mount type=bind,source="$data_dir",target="$data_target" \
|
||||
--rm \
|
||||
ipfs/go-ipfs:latest
|
||||
|
||||
# Change ownership
|
||||
docker exec "$name" chown 1000:1000 -R "$data_target"
|
||||
|
||||
# Wait for Daemon to be ready
|
||||
while read -r log; do
|
||||
echo "$log"
|
||||
if [[ "$log" == *"Daemon is ready"* ]]; then
|
||||
break
|
||||
fi
|
||||
done < <(docker logs -f "$name")
|
||||
|
||||
# Connect to peers
|
||||
for peer in "${peers[@]}"; do
|
||||
docker exec "$name" ipfs swarm connect "$peer" || echo "Error connecting to peer: $peer"
|
||||
done
|
File diff suppressed because it is too large
Load Diff
|
@ -1,2 +0,0 @@
|
|||
[virtualenvs]
|
||||
in-project = true
|
138
pyproject.toml
138
pyproject.toml
|
@ -1,67 +1,85 @@
|
|||
[tool.poetry]
|
||||
name = 'skynet'
|
||||
version = '0.1a12'
|
||||
description = 'Decentralized compute platform'
|
||||
authors = ['Guillermo Rodriguez <guillermo@telos.net>']
|
||||
license = 'AGPL'
|
||||
readme = 'README.md'
|
||||
[project]
|
||||
name = "skynet"
|
||||
version = "0.1a13"
|
||||
description = "Decentralized compute platform"
|
||||
authors = [{ name = "Guillermo Rodriguez", email = "guillermo@telos.net" }]
|
||||
requires-python = ">=3.10,<3.13"
|
||||
readme = "README.md"
|
||||
license = "AGPL-3.0-or-later"
|
||||
dependencies = [
|
||||
"pytz~=2023.3.post1",
|
||||
"trio>=0.22.2,<0.23",
|
||||
"Pillow>=10.0.1,<11",
|
||||
"docker>=6.1.3,<7",
|
||||
"py-leap",
|
||||
"toml>=0.10.2,<0.11",
|
||||
"msgspec>=0.19.0,<0.20",
|
||||
"numpy<2.1",
|
||||
"protobuf>=5.29.3,<6",
|
||||
"click>=8.1.8,<9",
|
||||
"httpx>=0.28.1,<0.29",
|
||||
"outcome>=1.3.0.post0",
|
||||
"urwid>=2.6.16",
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = '>=3.10,<3.12'
|
||||
pytz = '^2023.3.post1'
|
||||
trio = '^0.22.2'
|
||||
asks = '^3.0.0'
|
||||
Pillow = '^10.0.1'
|
||||
docker = '^6.1.3'
|
||||
py-leap = {git = 'https://github.com/guilledk/py-leap.git', rev = 'v0.1a14'}
|
||||
toml = '^0.10.2'
|
||||
[project.scripts]
|
||||
skynet = "skynet.cli:skynet"
|
||||
txt2img = "skynet.cli:txt2img"
|
||||
img2img = "skynet.cli:img2img"
|
||||
upscale = "skynet.cli:upscale"
|
||||
inpaint = "skynet.cli:inpaint"
|
||||
|
||||
[tool.poetry.group.frontend]
|
||||
optional = true
|
||||
[dependency-groups]
|
||||
frontend = [
|
||||
"triopg>=0.6.0,<0.7",
|
||||
"aiohttp>=3.8.5,<4",
|
||||
"psycopg2-binary>=2.9.7,<3",
|
||||
"pyTelegramBotAPI>=4.14.0,<5",
|
||||
"discord.py>=2.3.2,<3",
|
||||
]
|
||||
dev = [
|
||||
"pdbpp>=0.10.3,<0.11",
|
||||
"pytest>=7.4.2,<8",
|
||||
"pytest-dockerctl",
|
||||
"pytest-trio>=0.8.0,<0.9",
|
||||
]
|
||||
cuda = [
|
||||
"torch==2.5.1+cu121",
|
||||
"scipy==1.15.1",
|
||||
"numba==0.60.0",
|
||||
# "triton==3.1.0",
|
||||
# "xformers>=0.0.29,<0.0.30",
|
||||
"diffusers==0.32.1",
|
||||
"torchvision==0.20.1+cu121",
|
||||
"accelerate==0.34.0",
|
||||
"transformers==4.48.0",
|
||||
"huggingface-hub>=0.27.1,<0.28",
|
||||
"invisible-watermark>=0.2.0,<0.3",
|
||||
"bitsandbytes>=0.45.0,<0.46",
|
||||
"basicsr>=1.4.2,<2",
|
||||
"realesrgan>=0.3.0,<0.4",
|
||||
"sentencepiece>=0.2.0",
|
||||
]
|
||||
|
||||
[tool.poetry.group.frontend.dependencies]
|
||||
triopg = {version = '^0.6.0'}
|
||||
aiohttp = {version = '^3.8.5'}
|
||||
psycopg2-binary = {version = '^2.9.7'}
|
||||
pyTelegramBotAPI = {version = '^4.14.0'}
|
||||
'discord.py' = {version = '^2.3.2'}
|
||||
[tool.uv]
|
||||
default-groups = [
|
||||
"frontend",
|
||||
"dev",
|
||||
"cuda",
|
||||
]
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
[[tool.uv.index]]
|
||||
name = "torch"
|
||||
url = "https://download.pytorch.org/whl/cu121"
|
||||
explicit = true
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pdbpp = {version = '^0.10.3'}
|
||||
pytest = {version = '^7.4.2'}
|
||||
|
||||
[tool.poetry.group.cuda]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.cuda.dependencies]
|
||||
torch = {version = '2.0.1+cu118', source = 'torch'}
|
||||
scipy = {version = '^1.11.2'}
|
||||
numba = {version = '0.57.0'}
|
||||
quart = {version = '^0.19.3'}
|
||||
triton = {version = '2.0.0', source = 'torch'}
|
||||
basicsr = {version = '^1.4.2'}
|
||||
xformers = {version = '^0.0.22'}
|
||||
hypercorn = {version = '^0.14.4'}
|
||||
diffusers = {version = '^0.21.2'}
|
||||
realesrgan = {version = '^0.3.0'}
|
||||
quart-trio = {version = '^0.11.0'}
|
||||
torchvision = {version = '0.15.2+cu118', source = 'torch'}
|
||||
accelerate = {version = '^0.23.0'}
|
||||
transformers = {version = '^4.33.2'}
|
||||
huggingface-hub = {version = '^0.17.3'}
|
||||
invisible-watermark = {version = '^0.2.0'}
|
||||
|
||||
[[tool.poetry.source]]
|
||||
name = 'torch'
|
||||
url = 'https://download.pytorch.org/whl/cu118'
|
||||
priority = 'explicit'
|
||||
[tool.uv.sources]
|
||||
torch = { index = "torch" }
|
||||
triton = { index = "torch" }
|
||||
torchvision = { index = "torch" }
|
||||
py-leap = { git = "https://github.com/guilledk/py-leap.git", rev = "v0.1a34" }
|
||||
pytest-dockerctl = { git = "https://github.com/pikers/pytest-dockerctl.git", branch = "g_update" }
|
||||
|
||||
[build-system]
|
||||
requires = ['poetry-core', 'cython']
|
||||
build-backend = 'poetry.core.masonry.api'
|
||||
|
||||
[tool.poetry.scripts]
|
||||
skynet = 'skynet.cli:skynet'
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
|
|
@ -1,45 +1,42 @@
|
|||
# config sections are optional, depending on which services
|
||||
# you wish to run
|
||||
|
||||
[skynet.dgpu]
|
||||
[dgpu]
|
||||
account = 'testworkerX'
|
||||
permission = 'active'
|
||||
key = '5Xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
|
||||
node_url = 'https://testnet.skygpu.net'
|
||||
hyperion_url = 'https://testnet.skygpu.net'
|
||||
ipfs_gateway_url = '/ip4/169.197.140.154/tcp/4001/p2p/12D3KooWKWogLFNEcNNMKnzU7Snrnuj84RZdMBg3sLiQSQc51oEv'
|
||||
ipfs_url = 'http://127.0.0.1:5001'
|
||||
hf_home = 'hf_home'
|
||||
hf_token = 'hf_XxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXx'
|
||||
auto_withdraw = true
|
||||
non_compete = []
|
||||
api_bind = '127.0.0.1:42690'
|
||||
tui = true
|
||||
log_file = 'dgpu.log'
|
||||
log_level = 'info'
|
||||
|
||||
[skynet.telegram]
|
||||
[telegram]
|
||||
account = 'telegram'
|
||||
permission = 'active'
|
||||
key = '5Xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
|
||||
node_url = 'https://testnet.skygpu.net'
|
||||
hyperion_url = 'https://testnet.skygpu.net'
|
||||
ipfs_gateway_url = '/ip4/169.197.140.154/tcp/4001/p2p/12D3KooWKWogLFNEcNNMKnzU7Snrnuj84RZdMBg3sLiQSQc51oEv'
|
||||
ipfs_url = 'http://127.0.0.1:5001'
|
||||
token = 'XXXXXXXXXX:xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
|
||||
|
||||
[skynet.discord]
|
||||
[discord]
|
||||
account = 'discord'
|
||||
permission = 'active'
|
||||
key = '5Xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
|
||||
node_url = 'https://testnet.skygpu.net'
|
||||
hyperion_url = 'https://testnet.skygpu.net'
|
||||
ipfs_gateway_url = '/ip4/169.197.140.154/tcp/4001/p2p/12D3KooWKWogLFNEcNNMKnzU7Snrnuj84RZdMBg3sLiQSQc51oEv'
|
||||
ipfs_url = 'http://127.0.0.1:5001'
|
||||
token = 'XXXXXXXXXX:xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
|
||||
|
||||
[skynet.pinner]
|
||||
[pinner]
|
||||
hyperion_url = 'https://testnet.skygpu.net'
|
||||
ipfs_url = 'http://127.0.0.1:5001'
|
||||
|
||||
[skynet.user]
|
||||
[user]
|
||||
account = 'testuser'
|
||||
permission = 'active'
|
||||
key = '5Xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
|
||||
|
|
|
@ -1,2 +0,0 @@
|
|||
#!/usr/bin/python
|
||||
|
378
skynet/cli.py
378
skynet/cli.py
|
@ -1,5 +1,3 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
|
@ -8,10 +6,24 @@ from functools import partial
|
|||
|
||||
import click
|
||||
|
||||
from leap.sugar import Name, asset_from_str
|
||||
from leap.protocol import (
|
||||
Name,
|
||||
Asset,
|
||||
)
|
||||
|
||||
from .config import *
|
||||
from .constants import *
|
||||
from .config import (
|
||||
load_skynet_toml,
|
||||
set_hf_vars,
|
||||
ConfigParsingError,
|
||||
)
|
||||
from .constants import (
|
||||
# TODO, more conventional to make these private i'm pretty
|
||||
# sure according to pep8?
|
||||
DEFAULT_IPFS_DOMAIN,
|
||||
DEFAULT_EXPLORER_DOMAIN,
|
||||
DEFAULT_CONFIG_PATH,
|
||||
MODELS,
|
||||
)
|
||||
|
||||
|
||||
@click.group()
|
||||
|
@ -20,9 +32,12 @@ def skynet(*args, **kwargs):
|
|||
|
||||
|
||||
@click.command()
|
||||
@click.option('--model', '-m', default='midj')
|
||||
@click.option('--model', '-m', default=list(MODELS.keys())[-1])
|
||||
@click.option(
|
||||
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
|
||||
'--prompt',
|
||||
'-p',
|
||||
default='a red old tractor in a sunny wheat field',
|
||||
)
|
||||
@click.option('--output', '-o', default='output.png')
|
||||
@click.option('--width', '-w', default=512)
|
||||
@click.option('--height', '-h', default=512)
|
||||
|
@ -30,18 +45,24 @@ def skynet(*args, **kwargs):
|
|||
@click.option('--steps', '-s', default=26)
|
||||
@click.option('--seed', '-S', default=None)
|
||||
def txt2img(*args, **kwargs):
|
||||
from . import utils
|
||||
from skynet.dgpu import utils
|
||||
|
||||
config = load_skynet_toml()
|
||||
hf_token = load_key(config, 'skynet.dgpu.hf_token')
|
||||
hf_home = load_key(config, 'skynet.dgpu.hf_home')
|
||||
set_hf_vars(hf_token, hf_home)
|
||||
set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
|
||||
utils.txt2img(hf_token, **kwargs)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option('--model', '-m', default=list(MODELS.keys())[0])
|
||||
@click.option(
|
||||
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
|
||||
'--model',
|
||||
'-m',
|
||||
default=list(MODELS.keys())[-2]
|
||||
)
|
||||
@click.option(
|
||||
'--prompt',
|
||||
'-p',
|
||||
default='a red old tractor in a sunny wheat field',
|
||||
)
|
||||
@click.option('--input', '-i', default='input.png')
|
||||
@click.option('--output', '-o', default='output.png')
|
||||
@click.option('--strength', '-Z', default=1.0)
|
||||
|
@ -49,11 +70,9 @@ def txt2img(*args, **kwargs):
|
|||
@click.option('--steps', '-s', default=26)
|
||||
@click.option('--seed', '-S', default=None)
|
||||
def img2img(model, prompt, input, output, strength, guidance, steps, seed):
|
||||
from . import utils
|
||||
from skynet.dgpu import utils
|
||||
config = load_skynet_toml()
|
||||
hf_token = load_key(config, 'skynet.dgpu.hf_token')
|
||||
hf_home = load_key(config, 'skynet.dgpu.hf_home')
|
||||
set_hf_vars(hf_token, hf_home)
|
||||
set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
|
||||
utils.img2img(
|
||||
hf_token,
|
||||
model=model,
|
||||
|
@ -66,12 +85,41 @@ def img2img(model, prompt, input, output, strength, guidance, steps, seed):
|
|||
seed=seed
|
||||
)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option('--model', '-m', default=list(MODELS.keys())[-3])
|
||||
@click.option(
|
||||
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
|
||||
@click.option('--input', '-i', default='input.png')
|
||||
@click.option('--mask', '-M', default='mask.png')
|
||||
@click.option('--output', '-o', default='output.png')
|
||||
@click.option('--strength', '-Z', default=1.0)
|
||||
@click.option('--guidance', '-g', default=10.0)
|
||||
@click.option('--steps', '-s', default=26)
|
||||
@click.option('--seed', '-S', default=None)
|
||||
def inpaint(model, prompt, input, mask, output, strength, guidance, steps, seed):
|
||||
from skynet.dgpu import utils
|
||||
config = load_skynet_toml()
|
||||
set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
|
||||
utils.inpaint(
|
||||
hf_token,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
img_path=input,
|
||||
mask_path=mask,
|
||||
output=output,
|
||||
strength=strength,
|
||||
guidance=guidance,
|
||||
steps=steps,
|
||||
seed=seed
|
||||
)
|
||||
|
||||
@click.command()
|
||||
@click.option('--input', '-i', default='input.png')
|
||||
@click.option('--output', '-o', default='output.png')
|
||||
@click.option('--model', '-m', default='weights/RealESRGAN_x4plus.pth')
|
||||
def upscale(input, output, model):
|
||||
from . import utils
|
||||
from skynet.dgpu import utils
|
||||
utils.upscale(
|
||||
img_path=input,
|
||||
output=output,
|
||||
|
@ -80,120 +128,23 @@ def upscale(input, output, model):
|
|||
|
||||
@skynet.command()
|
||||
def download():
|
||||
from . import utils
|
||||
from skynet.dgpu import utils
|
||||
config = load_skynet_toml()
|
||||
hf_token = load_key(config, 'skynet.dgpu.hf_token')
|
||||
hf_home = load_key(config, 'skynet.dgpu.hf_home')
|
||||
set_hf_vars(hf_token, hf_home)
|
||||
utils.download_all_models(hf_token, hf_home)
|
||||
set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
|
||||
utils.download_all_models(config.dgpu.hf_token, config.dgpu.hf_home)
|
||||
|
||||
@skynet.command()
|
||||
@click.option(
|
||||
'--reward', '-r', default='20.0000 GPU')
|
||||
@click.option('--jobs', '-j', default=1)
|
||||
@click.option('--model', '-m', default='stabilityai/stable-diffusion-xl-base-1.0')
|
||||
@click.option(
|
||||
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
|
||||
@click.option('--output', '-o', default='output.png')
|
||||
@click.option('--width', '-w', default=1024)
|
||||
@click.option('--height', '-h', default=1024)
|
||||
@click.option('--guidance', '-g', default=10)
|
||||
@click.option('--step', '-s', default=26)
|
||||
@click.option('--seed', '-S', default=None)
|
||||
@click.option('--upscaler', '-U', default='x4')
|
||||
@click.option('--binary_data', '-b', default='')
|
||||
@click.option('--strength', '-Z', default=None)
|
||||
def enqueue(
|
||||
reward: str,
|
||||
jobs: int,
|
||||
**kwargs
|
||||
):
|
||||
import trio
|
||||
from leap.cleos import CLEOS
|
||||
|
||||
config = load_skynet_toml()
|
||||
|
||||
key = load_key(config, 'skynet.user.key')
|
||||
account = load_key(config, 'skynet.user.account')
|
||||
permission = load_key(config, 'skynet.user.permission')
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
|
||||
cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||
|
||||
binary = kwargs['binary_data']
|
||||
if not kwargs['strength']:
|
||||
if binary:
|
||||
raise ValueError('strength -Z param required if binary data passed')
|
||||
|
||||
del kwargs['strength']
|
||||
|
||||
else:
|
||||
kwargs['strength'] = float(kwargs['strength'])
|
||||
|
||||
async def enqueue_n_jobs():
|
||||
for i in range(jobs):
|
||||
if not kwargs['seed']:
|
||||
kwargs['seed'] = random.randint(0, 10e9)
|
||||
|
||||
req = json.dumps({
|
||||
'method': 'diffuse',
|
||||
'params': kwargs
|
||||
})
|
||||
|
||||
res = await cleos.a_push_action(
|
||||
'telos.gpu',
|
||||
'enqueue',
|
||||
{
|
||||
'user': Name(account),
|
||||
'request_body': req,
|
||||
'binary_data': binary,
|
||||
'reward': asset_from_str(reward),
|
||||
'min_verification': 1
|
||||
},
|
||||
account, key, permission,
|
||||
)
|
||||
print(res)
|
||||
|
||||
trio.run(enqueue_n_jobs)
|
||||
|
||||
|
||||
@skynet.command()
|
||||
@click.option('--loglevel', '-l', default='INFO', help='Logging level')
|
||||
def clean(
|
||||
loglevel: str,
|
||||
):
|
||||
import trio
|
||||
from leap.cleos import CLEOS
|
||||
|
||||
config = load_skynet_toml()
|
||||
key = load_key(config, 'skynet.user.key')
|
||||
account = load_key(config, 'skynet.user.account')
|
||||
permission = load_key(config, 'skynet.user.permission')
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
|
||||
logging.basicConfig(level=loglevel)
|
||||
cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||
trio.run(
|
||||
partial(
|
||||
cleos.a_push_action,
|
||||
'telos.gpu',
|
||||
'clean',
|
||||
{},
|
||||
account, key, permission=permission
|
||||
)
|
||||
)
|
||||
|
||||
@skynet.command()
|
||||
def queue():
|
||||
import requests
|
||||
config = load_skynet_toml()
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
node_url = config.user.node_url
|
||||
resp = requests.post(
|
||||
f'{node_url}/v1/chain/get_table_rows',
|
||||
json={
|
||||
'code': 'telos.gpu',
|
||||
'code': 'gpu.scd',
|
||||
'table': 'queue',
|
||||
'scope': 'telos.gpu',
|
||||
'scope': 'gpu.scd',
|
||||
'json': True
|
||||
}
|
||||
)
|
||||
|
@ -204,11 +155,11 @@ def queue():
|
|||
def status(request_id: int):
|
||||
import requests
|
||||
config = load_skynet_toml()
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
node_url = config.user.node_url
|
||||
resp = requests.post(
|
||||
f'{node_url}/v1/chain/get_table_rows',
|
||||
json={
|
||||
'code': 'telos.gpu',
|
||||
'code': 'gpu.scd',
|
||||
'table': 'status',
|
||||
'scope': request_id,
|
||||
'json': True
|
||||
|
@ -216,100 +167,6 @@ def status(request_id: int):
|
|||
)
|
||||
print(json.dumps(resp.json(), indent=4))
|
||||
|
||||
@skynet.command()
|
||||
@click.argument('request-id')
|
||||
def dequeue(request_id: int):
|
||||
import trio
|
||||
from leap.cleos import CLEOS
|
||||
|
||||
config = load_skynet_toml()
|
||||
key = load_key(config, 'skynet.user.key')
|
||||
account = load_key(config, 'skynet.user.account')
|
||||
permission = load_key(config, 'skynet.user.permission')
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
|
||||
cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||
res = trio.run(
|
||||
partial(
|
||||
cleos.a_push_action,
|
||||
'telos.gpu',
|
||||
'dequeue',
|
||||
{
|
||||
'user': Name(account),
|
||||
'request_id': int(request_id),
|
||||
},
|
||||
account, key, permission=permission
|
||||
)
|
||||
)
|
||||
print(res)
|
||||
|
||||
|
||||
@skynet.command()
|
||||
@click.option(
|
||||
'--token-contract', '-c', default='eosio.token')
|
||||
@click.option(
|
||||
'--token-symbol', '-S', default='4,GPU')
|
||||
def config(
|
||||
token_contract: str,
|
||||
token_symbol: str
|
||||
):
|
||||
import trio
|
||||
from leap.cleos import CLEOS
|
||||
|
||||
config = load_skynet_toml()
|
||||
|
||||
key = load_key(config, 'skynet.user.key')
|
||||
account = load_key(config, 'skynet.user.account')
|
||||
permission = load_key(config, 'skynet.user.permission')
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
|
||||
cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||
res = trio.run(
|
||||
partial(
|
||||
cleos.a_push_action,
|
||||
'telos.gpu',
|
||||
'config',
|
||||
{
|
||||
'token_contract': token_contract,
|
||||
'token_symbol': token_symbol,
|
||||
},
|
||||
account, key, permission=permission
|
||||
)
|
||||
)
|
||||
print(res)
|
||||
|
||||
|
||||
@skynet.command()
|
||||
@click.argument('quantity')
|
||||
def deposit(quantity: str):
|
||||
import trio
|
||||
from leap.cleos import CLEOS
|
||||
|
||||
config = load_skynet_toml()
|
||||
|
||||
key = load_key(config, 'skynet.user.key')
|
||||
account = load_key(config, 'skynet.user.account')
|
||||
permission = load_key(config, 'skynet.user.permission')
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||
|
||||
res = trio.run(
|
||||
partial(
|
||||
cleos.a_push_action,
|
||||
'telos.gpu',
|
||||
'transfer',
|
||||
{
|
||||
'sender': Name(account),
|
||||
'recipient': Name('telos.gpu'),
|
||||
'amount': asset_from_str(quantity),
|
||||
'memo': f'{account} transferred {quantity} to telos.gpu'
|
||||
},
|
||||
account, key, permission=permission
|
||||
)
|
||||
)
|
||||
print(res)
|
||||
|
||||
|
||||
@skynet.group()
|
||||
def run(*args, **kwargs):
|
||||
pass
|
||||
|
@ -323,36 +180,27 @@ def db():
|
|||
container, passwd, host = db_params
|
||||
logging.info(('skynet', passwd, host))
|
||||
|
||||
@run.command()
|
||||
def nodeos():
|
||||
from .nodeos import open_nodeos
|
||||
|
||||
logging.basicConfig(filename='skynet-nodeos.log', level=logging.INFO)
|
||||
with open_nodeos(cleanup=False):
|
||||
...
|
||||
|
||||
@run.command()
|
||||
@click.option('--loglevel', '-l', default='INFO', help='Logging level')
|
||||
@click.option(
|
||||
'--config-path', '-c', default=DEFAULT_CONFIG_PATH)
|
||||
'--config-path',
|
||||
'-c',
|
||||
default=DEFAULT_CONFIG_PATH,
|
||||
)
|
||||
def dgpu(
|
||||
loglevel: str,
|
||||
config_path: str
|
||||
):
|
||||
import trio
|
||||
from .dgpu import open_dgpu_node
|
||||
from .dgpu import _dgpu_main
|
||||
|
||||
logging.basicConfig(level=loglevel)
|
||||
|
||||
config = load_skynet_toml(file_path=config_path)
|
||||
hf_token = load_key(config, 'skynet.dgpu.hf_token')
|
||||
hf_home = load_key(config, 'skynet.dgpu.hf_home')
|
||||
set_hf_vars(hf_token, hf_home)
|
||||
config = load_skynet_toml(file_path=config_path).dgpu
|
||||
set_hf_vars(config.hf_token, config.hf_home)
|
||||
|
||||
assert 'skynet' in config
|
||||
assert 'dgpu' in config['skynet']
|
||||
|
||||
trio.run(open_dgpu_node, config['skynet']['dgpu'])
|
||||
trio.run(_dgpu_main, config)
|
||||
|
||||
|
||||
@run.command()
|
||||
|
@ -375,30 +223,24 @@ def telegram(
|
|||
logging.basicConfig(level=loglevel)
|
||||
|
||||
config = load_skynet_toml()
|
||||
tg_token = load_key(config, 'skynet.telegram.tg_token')
|
||||
tg_token = config.telegram.tg_token
|
||||
|
||||
key = load_key(config, 'skynet.telegram.key')
|
||||
account = load_key(config, 'skynet.telegram.account')
|
||||
permission = load_key(config, 'skynet.telegram.permission')
|
||||
node_url = load_key(config, 'skynet.telegram.node_url')
|
||||
hyperion_url = load_key(config, 'skynet.telegram.hyperion_url')
|
||||
key = config.telegram.key
|
||||
account = config.telegram.account
|
||||
permission = config.telegram.permission
|
||||
node_url = config.telegram.node_url
|
||||
hyperion_url = config.telegram.hyperion_url
|
||||
|
||||
ipfs_url = config.telegram.ipfs_url
|
||||
|
||||
try:
|
||||
ipfs_gateway_url = load_key(config, 'skynet.telegram.ipfs_gateway_url')
|
||||
|
||||
except ConfigParsingError:
|
||||
ipfs_gateway_url = None
|
||||
|
||||
ipfs_url = load_key(config, 'skynet.telegram.ipfs_url')
|
||||
|
||||
try:
|
||||
explorer_domain = load_key(config, 'skynet.telegram.explorer_domain')
|
||||
explorer_domain = config.telegram.explorer_domain
|
||||
|
||||
except ConfigParsingError:
|
||||
explorer_domain = DEFAULT_EXPLORER_DOMAIN
|
||||
|
||||
try:
|
||||
ipfs_domain = load_key(config, 'skynet.telegram.ipfs_domain')
|
||||
ipfs_domain = config.telegram.ipfs_domain
|
||||
|
||||
except ConfigParsingError:
|
||||
ipfs_domain = DEFAULT_IPFS_DOMAIN
|
||||
|
@ -412,7 +254,6 @@ def telegram(
|
|||
hyperion_url,
|
||||
db_host, db_user, db_pass,
|
||||
ipfs_url,
|
||||
remote_ipfs_node=ipfs_gateway_url,
|
||||
key=key,
|
||||
explorer_domain=explorer_domain,
|
||||
ipfs_domain=ipfs_domain
|
||||
|
@ -445,25 +286,24 @@ def discord(
|
|||
logging.basicConfig(level=loglevel)
|
||||
|
||||
config = load_skynet_toml()
|
||||
dc_token = load_key(config, 'skynet.discord.dc_token')
|
||||
dc_token = config.discord.dc_token
|
||||
|
||||
key = load_key(config, 'skynet.discord.key')
|
||||
account = load_key(config, 'skynet.discord.account')
|
||||
permission = load_key(config, 'skynet.discord.permission')
|
||||
node_url = load_key(config, 'skynet.discord.node_url')
|
||||
hyperion_url = load_key(config, 'skynet.discord.hyperion_url')
|
||||
key = config.discord.key
|
||||
account = config.discord.account
|
||||
permission = config.discord.permission
|
||||
node_url = config.discord.node_url
|
||||
hyperion_url = config.discord.hyperion_url
|
||||
|
||||
ipfs_gateway_url = load_key(config, 'skynet.discord.ipfs_gateway_url')
|
||||
ipfs_url = load_key(config, 'skynet.discord.ipfs_url')
|
||||
ipfs_url = config.discord.ipfs_url
|
||||
|
||||
try:
|
||||
explorer_domain = load_key(config, 'skynet.discord.explorer_domain')
|
||||
explorer_domain = config.discord.explorer_domain
|
||||
|
||||
except ConfigParsingError:
|
||||
explorer_domain = DEFAULT_EXPLORER_DOMAIN
|
||||
|
||||
try:
|
||||
ipfs_domain = load_key(config, 'skynet.discord.ipfs_domain')
|
||||
ipfs_domain = config.discord.ipfs_domain
|
||||
|
||||
except ConfigParsingError:
|
||||
ipfs_domain = DEFAULT_IPFS_DOMAIN
|
||||
|
@ -477,7 +317,6 @@ def discord(
|
|||
hyperion_url,
|
||||
db_host, db_user, db_pass,
|
||||
ipfs_url,
|
||||
remote_ipfs_node=ipfs_gateway_url,
|
||||
key=key,
|
||||
explorer_domain=explorer_domain,
|
||||
ipfs_domain=ipfs_domain
|
||||
|
@ -489,17 +328,6 @@ def discord(
|
|||
asyncio.run(_async_main())
|
||||
|
||||
|
||||
@run.command()
|
||||
@click.option('--loglevel', '-l', default='INFO', help='logging level')
|
||||
@click.option('--name', '-n', default='skynet-ipfs', help='container name')
|
||||
@click.option('--peer', '-p', default=(), help='connect to peer', multiple=True, type=str)
|
||||
def ipfs(loglevel, name, peer):
|
||||
from skynet.ipfs.docker import open_ipfs_node
|
||||
|
||||
logging.basicConfig(level=loglevel)
|
||||
with open_ipfs_node(name=name, peers=peer):
|
||||
...
|
||||
|
||||
@run.command()
|
||||
@click.option('--loglevel', '-l', default='INFO', help='logging level')
|
||||
def pinner(loglevel):
|
||||
|
@ -509,8 +337,8 @@ def pinner(loglevel):
|
|||
from .ipfs.pinner import SkynetPinner
|
||||
|
||||
config = load_skynet_toml()
|
||||
hyperion_url = load_key(config, 'skynet.pinner.hyperion_url')
|
||||
ipfs_url = load_key(config, 'skynet.pinner.ipfs_url')
|
||||
hyperion_url = config.pinner.hyperion_url
|
||||
ipfs_url = config.pinner.ipfs_url
|
||||
|
||||
logging.basicConfig(level=loglevel)
|
||||
ipfs_node = AsyncIPFSHTTP(ipfs_url)
|
||||
|
|
|
@ -1,31 +1,67 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import os
|
||||
import toml
|
||||
|
||||
from pathlib import Path
|
||||
import msgspec
|
||||
|
||||
from .constants import DEFAULT_CONFIG_PATH
|
||||
from skynet.constants import DEFAULT_CONFIG_PATH, DEFAULT_IPFS_DOMAIN
|
||||
|
||||
|
||||
class ConfigParsingError(BaseException):
|
||||
...
|
||||
|
||||
|
||||
def load_skynet_toml(file_path=DEFAULT_CONFIG_PATH) -> dict:
|
||||
config = toml.load(file_path)
|
||||
return config
|
||||
class DgpuConfig(msgspec.Struct):
|
||||
account: str # worker account name
|
||||
permission: str # account permission name associated with key
|
||||
key: str # private key
|
||||
node_url: str # antelope http api endpoint
|
||||
ipfs_url: str # IPFS node http rpc endpoint
|
||||
hf_token: str # hugging face token
|
||||
ipfs_domain: str = DEFAULT_IPFS_DOMAIN # IPFS Gateway domain
|
||||
hf_home: str = 'hf_home' # hugging face data cache location
|
||||
non_compete: set[str] = set() # set of worker names to not compete in requests
|
||||
model_whitelist: set[str] = set() # only run these models
|
||||
model_blacklist: set[str] = set() # don't run this models
|
||||
backend: str = 'sync-on-thread' # select inference backend
|
||||
tui: bool = False # enable TUI monitor
|
||||
poll_time: float = 0.5 # wait time for polling updates from contract
|
||||
log_level: str = 'info'
|
||||
log_file: str = 'dgpu.log' # log file path (only used when tui = true)
|
||||
|
||||
|
||||
def load_key(config: dict, key: str) -> str:
|
||||
for skey in key.split('.'):
|
||||
if skey not in config:
|
||||
conf_keys = [k for k in config]
|
||||
raise ConfigParsingError(f'key \"{skey}\" not in {conf_keys}')
|
||||
class FrontendConfig(msgspec.Struct):
|
||||
account: str
|
||||
permission: str
|
||||
key: str
|
||||
node_url: str
|
||||
hyperion_url: str
|
||||
ipfs_url: str
|
||||
token: str
|
||||
|
||||
config = config[skey]
|
||||
|
||||
return config
|
||||
class PinnerConfig(msgspec.Struct):
|
||||
hyperion_url: str
|
||||
ipfs_url: str
|
||||
|
||||
|
||||
class UserConfig(msgspec.Struct):
|
||||
account: str
|
||||
permission: str
|
||||
key: str
|
||||
node_url: str
|
||||
|
||||
|
||||
class Config(msgspec.Struct):
|
||||
dgpu: DgpuConfig | None = None
|
||||
telegram: FrontendConfig | None = None
|
||||
discord: FrontendConfig | None = None
|
||||
pinner: PinnerConfig | None = None
|
||||
user: UserConfig | None = None
|
||||
|
||||
|
||||
def load_skynet_toml(file_path=DEFAULT_CONFIG_PATH) -> Config:
|
||||
with open(file_path, 'r') as file:
|
||||
return msgspec.toml.decode(file.read(), type=Config)
|
||||
|
||||
|
||||
def set_hf_vars(hf_token: str, hf_home: str):
|
||||
|
|
|
@ -1,34 +1,127 @@
|
|||
#!/usr/bin/python
|
||||
import msgspec
|
||||
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
VERSION = '0.1a12'
|
||||
|
||||
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
|
||||
|
||||
MODELS = {
|
||||
'prompthero/openjourney': {'short': 'midj', 'mem': 6},
|
||||
'runwayml/stable-diffusion-v1-5': {'short': 'stable', 'mem': 6},
|
||||
'stabilityai/stable-diffusion-2-1-base': {'short': 'stable2', 'mem': 6},
|
||||
'snowkidy/stable-diffusion-xl-base-0.9': {'short': 'stablexl0.9', 'mem': 8.3},
|
||||
'Linaqruf/anything-v3.0': {'short': 'hdanime', 'mem': 6},
|
||||
'hakurei/waifu-diffusion': {'short': 'waifu', 'mem': 6},
|
||||
'nitrosocke/Ghibli-Diffusion': {'short': 'ghibli', 'mem': 6},
|
||||
'dallinmackay/Van-Gogh-diffusion': {'short': 'van-gogh', 'mem': 6},
|
||||
'lambdalabs/sd-pokemon-diffusers': {'short': 'pokemon', 'mem': 6},
|
||||
'Envvi/Inkpunk-Diffusion': {'short': 'ink', 'mem': 6},
|
||||
'nousr/robo-diffusion': {'short': 'robot', 'mem': 6},
|
||||
|
||||
# default is always last
|
||||
'stabilityai/stable-diffusion-xl-base-1.0': {'short': 'stablexl', 'mem': 8.3},
|
||||
class ModelDesc(msgspec.Struct):
|
||||
short: str # short unique name
|
||||
mem: float # recomended mem
|
||||
attrs: dict # additional mode specific attrs
|
||||
tags: list[Literal['txt2img', 'img2img', 'inpaint', 'upscale']]
|
||||
|
||||
|
||||
MODELS: dict[str, ModelDesc] = {
|
||||
'RealESRGAN_x4plus': ModelDesc(
|
||||
short='realesrgan',
|
||||
mem=4,
|
||||
attrs={},
|
||||
tags=['upscale']
|
||||
),
|
||||
'runwayml/stable-diffusion-v1-5': ModelDesc(
|
||||
short='stable',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'stabilityai/stable-diffusion-2-1-base': ModelDesc(
|
||||
short='stable2',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'snowkidy/stable-diffusion-xl-base-0.9': ModelDesc(
|
||||
short='stablexl0.9',
|
||||
mem=8.3,
|
||||
attrs={'size': {'w': 1024, 'h': 1024}},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'Linaqruf/anything-v3.0': ModelDesc(
|
||||
short='hdanime',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'hakurei/waifu-diffusion': ModelDesc(
|
||||
short='waifu',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'nitrosocke/Ghibli-Diffusion': ModelDesc(
|
||||
short='ghibli',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'dallinmackay/Van-Gogh-diffusion': ModelDesc(
|
||||
short='van-gogh',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'lambdalabs/sd-pokemon-diffusers': ModelDesc(
|
||||
short='pokemon',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'Envvi/Inkpunk-Diffusion': ModelDesc(
|
||||
short='ink',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'nousr/robo-diffusion': ModelDesc(
|
||||
short='robot',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'black-forest-labs/FLUX.1-schnell': ModelDesc(
|
||||
short='flux',
|
||||
mem=24,
|
||||
attrs={'size': {'w': 1024, 'h': 1024}},
|
||||
tags=['txt2img']
|
||||
),
|
||||
'black-forest-labs/FLUX.1-Fill-dev': ModelDesc(
|
||||
short='flux-inpaint',
|
||||
mem=24,
|
||||
attrs={'size': {'w': 1024, 'h': 1024}},
|
||||
tags=['inpaint']
|
||||
),
|
||||
'diffusers/stable-diffusion-xl-1.0-inpainting-0.1': ModelDesc(
|
||||
short='stablexl-inpaint',
|
||||
mem=8.3,
|
||||
attrs={'size': {'w': 1024, 'h': 1024}},
|
||||
tags=['inpaint']
|
||||
),
|
||||
'prompthero/openjourney': ModelDesc(
|
||||
short='midj',
|
||||
mem=6,
|
||||
attrs={'size': {'w': 512, 'h': 512}},
|
||||
tags=['txt2img', 'img2img']
|
||||
),
|
||||
'stabilityai/stable-diffusion-xl-base-1.0': ModelDesc(
|
||||
short='stablexl',
|
||||
mem=8.3,
|
||||
attrs={'size': {'w': 1024, 'h': 1024}},
|
||||
tags=['txt2img']
|
||||
),
|
||||
}
|
||||
|
||||
SHORT_NAMES = [
|
||||
model_info['short']
|
||||
model_info.short
|
||||
for model_info in MODELS.values()
|
||||
]
|
||||
|
||||
def get_model_by_shortname(short: str):
|
||||
for model, info in MODELS.items():
|
||||
if short == info['short']:
|
||||
if short == info.short:
|
||||
return model
|
||||
|
||||
N = '\n'
|
||||
|
@ -166,9 +259,7 @@ DEFAULT_UPSCALER = None
|
|||
|
||||
DEFAULT_CONFIG_PATH = 'skynet.toml'
|
||||
|
||||
DEFAULT_INITAL_MODELS = [
|
||||
'stabilityai/stable-diffusion-xl-base-1.0'
|
||||
]
|
||||
DEFAULT_INITAL_MODEL = list(MODELS.keys())[-1]
|
||||
|
||||
DATE_FORMAT = '%B the %dth %Y, %H:%M:%S'
|
||||
|
||||
|
@ -193,3 +284,221 @@ TG_MAX_WIDTH = 1280
|
|||
TG_MAX_HEIGHT = 1280
|
||||
|
||||
DEFAULT_SINGLE_CARD_MAP = 'cuda:0'
|
||||
|
||||
GPU_CONTRACT_ABI = {
|
||||
"version": "eosio::abi/1.2",
|
||||
"types": [],
|
||||
"structs": [
|
||||
{
|
||||
"name": "account",
|
||||
"base": "",
|
||||
"fields": [
|
||||
{"name": "user", "type": "name"},
|
||||
{"name": "balance", "type": "asset"},
|
||||
{"name": "nonce", "type": "uint64"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "card",
|
||||
"base": "",
|
||||
"fields": [
|
||||
{"name": "id", "type": "uint64"},
|
||||
{"name": "owner", "type": "name"},
|
||||
{"name": "card_name", "type": "string"},
|
||||
{"name": "version", "type": "string"},
|
||||
{"name": "total_memory", "type": "uint64"},
|
||||
{"name": "mp_count", "type": "uint32"},
|
||||
{"name": "extra", "type": "string"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "clean",
|
||||
"base": "",
|
||||
"fields": []
|
||||
},
|
||||
{
|
||||
"name": "config",
|
||||
"base": "",
|
||||
"fields": [
|
||||
{"name": "token_contract", "type": "name"},
|
||||
{"name": "token_symbol", "type": "symbol"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "dequeue",
|
||||
"base": "",
|
||||
"fields": [
|
||||
{"name": "user", "type": "name"},
|
||||
{"name": "request_id", "type": "uint64"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "enqueue",
|
||||
"base": "",
|
||||
"fields": [
|
||||
{"name": "user", "type": "name"},
|
||||
{"name": "request_body", "type": "string"},
|
||||
{"name": "binary_data", "type": "string"},
|
||||
{"name": "reward", "type": "asset"},
|
||||
{"name": "min_verification", "type": "uint32"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "gcfgstruct",
|
||||
"base": "",
|
||||
"fields": [
|
||||
{"name": "token_contract", "type": "name"},
|
||||
{"name": "token_symbol", "type": "symbol"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "submit",
|
||||
"base": "",
|
||||
"fields": [
|
||||
{"name": "worker", "type": "name"},
|
||||
{"name": "request_id", "type": "uint64"},
|
||||
{"name": "request_hash", "type": "checksum256"},
|
||||
{"name": "result_hash", "type": "checksum256"},
|
||||
{"name": "ipfs_hash", "type": "string"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "withdraw",
|
||||
"base": "",
|
||||
"fields": [
|
||||
{"name": "user", "type": "name"},
|
||||
{"name": "quantity", "type": "asset"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "work_request_struct",
|
||||
"base": "",
|
||||
"fields": [
|
||||
{"name": "id", "type": "uint64"},
|
||||
{"name": "user", "type": "name"},
|
||||
{"name": "reward", "type": "asset"},
|
||||
{"name": "min_verification", "type": "uint32"},
|
||||
{"name": "nonce", "type": "uint64"},
|
||||
{"name": "body", "type": "string"},
|
||||
{"name": "binary_data", "type": "string"},
|
||||
{"name": "timestamp", "type": "time_point_sec"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "work_result_struct",
|
||||
"base": "",
|
||||
"fields": [
|
||||
{"name": "id", "type": "uint64"},
|
||||
{"name": "request_id", "type": "uint64"},
|
||||
{"name": "user", "type": "name"},
|
||||
{"name": "worker", "type": "name"},
|
||||
{"name": "result_hash", "type": "checksum256"},
|
||||
{"name": "ipfs_hash", "type": "string"},
|
||||
{"name": "submited", "type": "time_point_sec"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "workbegin",
|
||||
"base": "",
|
||||
"fields": [
|
||||
{"name": "worker", "type": "name"},
|
||||
{"name": "request_id", "type": "uint64"},
|
||||
{"name": "max_workers", "type": "uint32"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "workcancel",
|
||||
"base": "",
|
||||
"fields": [
|
||||
{"name": "worker", "type": "name"},
|
||||
{"name": "request_id", "type": "uint64"},
|
||||
{"name": "reason", "type": "string"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "worker",
|
||||
"base": "",
|
||||
"fields": [
|
||||
{"name": "account", "type": "name"},
|
||||
{"name": "joined", "type": "time_point_sec"},
|
||||
{"name": "left", "type": "time_point_sec"},
|
||||
{"name": "url", "type": "string"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "worker_status_struct",
|
||||
"base": "",
|
||||
"fields": [
|
||||
{"name": "worker", "type": "name"},
|
||||
{"name": "status", "type": "string"},
|
||||
{"name": "started", "type": "time_point_sec"}
|
||||
]
|
||||
}
|
||||
],
|
||||
"actions": [
|
||||
{"name": "clean", "type": "clean", "ricardian_contract": ""},
|
||||
{"name": "config", "type": "config", "ricardian_contract": ""},
|
||||
{"name": "dequeue", "type": "dequeue", "ricardian_contract": ""},
|
||||
{"name": "enqueue", "type": "enqueue", "ricardian_contract": ""},
|
||||
{"name": "submit", "type": "submit", "ricardian_contract": ""},
|
||||
{"name": "withdraw", "type": "withdraw", "ricardian_contract": ""},
|
||||
{"name": "workbegin", "type": "workbegin", "ricardian_contract": ""},
|
||||
{"name": "workcancel", "type": "workcancel", "ricardian_contract": ""}
|
||||
],
|
||||
"tables": [
|
||||
{
|
||||
"name": "cards",
|
||||
"index_type": "i64",
|
||||
"key_names": [],
|
||||
"key_types": [],
|
||||
"type": "card"
|
||||
},
|
||||
{
|
||||
"name": "gcfgstruct",
|
||||
"index_type": "i64",
|
||||
"key_names": [],
|
||||
"key_types": [],
|
||||
"type": "gcfgstruct"
|
||||
},
|
||||
{
|
||||
"name": "queue",
|
||||
"index_type": "i64",
|
||||
"key_names": [],
|
||||
"key_types": [],
|
||||
"type": "work_request_struct"
|
||||
},
|
||||
{
|
||||
"name": "results",
|
||||
"index_type": "i64",
|
||||
"key_names": [],
|
||||
"key_types": [],
|
||||
"type": "work_result_struct"
|
||||
},
|
||||
{
|
||||
"name": "status",
|
||||
"index_type": "i64",
|
||||
"key_names": [],
|
||||
"key_types": [],
|
||||
"type": "worker_status_struct"
|
||||
},
|
||||
{
|
||||
"name": "users",
|
||||
"index_type": "i64",
|
||||
"key_names": [],
|
||||
"key_types": [],
|
||||
"type": "account"
|
||||
},
|
||||
{
|
||||
"name": "workers",
|
||||
"index_type": "i64",
|
||||
"key_names": [],
|
||||
"key_types": [],
|
||||
"type": "worker"
|
||||
}
|
||||
],
|
||||
"ricardian_clauses": [],
|
||||
"error_messages": [],
|
||||
"abi_extensions": [],
|
||||
"variants": [],
|
||||
"action_results": []
|
||||
}
|
||||
|
|
|
@ -1,3 +1 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
from .functions import open_new_database, open_database_connection
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import time
|
||||
import random
|
||||
import string
|
||||
|
|
|
@ -1,30 +1,40 @@
|
|||
#!/usr/bin/python
|
||||
import logging
|
||||
from contextlib import asynccontextmanager as acm
|
||||
|
||||
import trio
|
||||
import urwid
|
||||
|
||||
from hypercorn.config import Config
|
||||
from hypercorn.trio import serve
|
||||
|
||||
from skynet.dgpu.compute import SkynetMM
|
||||
from skynet.dgpu.daemon import SkynetDGPUDaemon
|
||||
from skynet.dgpu.network import SkynetGPUConnector
|
||||
from skynet.config import Config
|
||||
from skynet.dgpu.tui import init_tui
|
||||
from skynet.dgpu.daemon import dgpu_serve_forever
|
||||
from skynet.dgpu.network import NetConnector
|
||||
|
||||
|
||||
async def open_dgpu_node(config: dict):
|
||||
conn = SkynetGPUConnector(config)
|
||||
mm = SkynetMM(config)
|
||||
daemon = SkynetDGPUDaemon(mm, conn, config)
|
||||
@acm
|
||||
async def open_worker(config: Config):
|
||||
# suppress logs from httpx (logs url + status after every query)
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
|
||||
api = None
|
||||
if 'api_bind' in config:
|
||||
api_conf = Config()
|
||||
api_conf.bind = [config['api_bind']]
|
||||
api = await daemon.generate_api()
|
||||
tui = None
|
||||
if config.tui:
|
||||
tui = init_tui(config)
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
n.start_soon(daemon.snap_updater_task)
|
||||
conn = NetConnector(config)
|
||||
|
||||
if api:
|
||||
n.start_soon(serve, api, api_conf)
|
||||
try:
|
||||
n: trio.Nursery
|
||||
async with trio.open_nursery() as n:
|
||||
if tui:
|
||||
n.start_soon(tui.run)
|
||||
|
||||
await daemon.serve_forever()
|
||||
n.start_soon(conn.iter_poll_update, config.poll_time)
|
||||
|
||||
yield conn
|
||||
|
||||
except *urwid.ExitMainLoop:
|
||||
...
|
||||
|
||||
|
||||
async def _dgpu_main(config: Config):
|
||||
async with open_worker(config) as conn:
|
||||
await dgpu_serve_forever(config, conn)
|
||||
|
|
|
@ -1,48 +1,64 @@
|
|||
#!/usr/bin/python
|
||||
'''
|
||||
Skynet Memory Manager
|
||||
|
||||
# Skynet Memory Manager
|
||||
'''
|
||||
|
||||
import gc
|
||||
import logging
|
||||
|
||||
from hashlib import sha256
|
||||
import zipfile
|
||||
from PIL import Image
|
||||
from diffusers import DiffusionPipeline
|
||||
from contextlib import contextmanager as cm
|
||||
|
||||
import trio
|
||||
import torch
|
||||
|
||||
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
|
||||
from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled
|
||||
from skynet.config import load_skynet_toml
|
||||
from skynet.dgpu.tui import maybe_update_tui
|
||||
from skynet.dgpu.errors import (
|
||||
DGPUComputeError,
|
||||
DGPUInferenceCancelled,
|
||||
)
|
||||
|
||||
from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
|
||||
from skynet.dgpu.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
|
||||
|
||||
|
||||
def prepare_params_for_diffuse(
|
||||
params: dict,
|
||||
input_type: str,
|
||||
binary = None
|
||||
mode: str,
|
||||
inputs: list[bytes]
|
||||
):
|
||||
_params = {}
|
||||
if binary != None:
|
||||
match input_type:
|
||||
case 'png':
|
||||
image = crop_image(
|
||||
binary, params['width'], params['height'])
|
||||
match mode:
|
||||
case 'inpaint':
|
||||
image = crop_image(
|
||||
inputs[0], params['width'], params['height'])
|
||||
|
||||
_params['image'] = image
|
||||
mask = crop_image(
|
||||
inputs[1], params['width'], params['height'])
|
||||
|
||||
_params['image'] = image
|
||||
_params['mask_image'] = mask
|
||||
|
||||
if 'flux' in params['model'].lower():
|
||||
_params['max_sequence_length'] = 512
|
||||
else:
|
||||
_params['strength'] = float(params['strength'])
|
||||
|
||||
case 'none':
|
||||
...
|
||||
case 'img2img':
|
||||
image = crop_image(
|
||||
inputs[0], params['width'], params['height'])
|
||||
|
||||
case _:
|
||||
raise DGPUComputeError(f'Unknown input_type {input_type}')
|
||||
_params['image'] = image
|
||||
_params['strength'] = float(params['strength'])
|
||||
|
||||
else:
|
||||
_params['width'] = int(params['width'])
|
||||
_params['height'] = int(params['height'])
|
||||
case 'txt2img' | 'diffuse':
|
||||
...
|
||||
|
||||
case _:
|
||||
raise DGPUComputeError(f'Unknown mode {mode}')
|
||||
|
||||
# _params['width'] = int(params['width'])
|
||||
# _params['height'] = int(params['height'])
|
||||
|
||||
return (
|
||||
params['prompt'],
|
||||
|
@ -53,158 +69,140 @@ def prepare_params_for_diffuse(
|
|||
_params
|
||||
)
|
||||
|
||||
_model_name: str = ''
|
||||
_model_mode: str = ''
|
||||
_model = None
|
||||
|
||||
class SkynetMM:
|
||||
@cm
|
||||
def maybe_load_model(name: str, mode: str):
|
||||
if mode == 'diffuse':
|
||||
mode = 'txt2img'
|
||||
|
||||
def __init__(self, config: dict):
|
||||
self.upscaler = init_upscaler()
|
||||
self.initial_models = (
|
||||
config['initial_models']
|
||||
if 'initial_models' in config else DEFAULT_INITAL_MODELS
|
||||
)
|
||||
global _model_name, _model_mode, _model
|
||||
config = load_skynet_toml().dgpu
|
||||
|
||||
self.cache_dir = None
|
||||
if 'hf_home' in config:
|
||||
self.cache_dir = config['hf_home']
|
||||
if _model_name != name or _model_mode != mode:
|
||||
# unload model
|
||||
_model = None
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self._models = {}
|
||||
for model in self.initial_models:
|
||||
self.load_model(model, False, force=True)
|
||||
_model_name = _model_mode = ''
|
||||
|
||||
def log_debug_info(self):
|
||||
logging.info('memory summary:')
|
||||
logging.info('\n' + torch.cuda.memory_summary())
|
||||
|
||||
def is_model_loaded(self, model_name: str, image: bool):
|
||||
for model_key, model_data in self._models.items():
|
||||
if (model_key == model_name and
|
||||
model_data['image'] == image):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
model_name: str,
|
||||
image: bool,
|
||||
force=False
|
||||
):
|
||||
logging.info(f'loading model {model_name}...')
|
||||
if force or len(self._models.keys()) == 0:
|
||||
pipe = pipeline_for(
|
||||
model_name, image=image, cache_dir=self.cache_dir)
|
||||
|
||||
self._models[model_name] = {
|
||||
'pipe': pipe,
|
||||
'generated': 0,
|
||||
'image': image
|
||||
}
|
||||
# load model
|
||||
if mode == 'upscale':
|
||||
_model = init_upscaler()
|
||||
|
||||
else:
|
||||
least_used = list(self._models.keys())[0]
|
||||
_model = pipeline_for(
|
||||
name, mode, cache_dir=config.hf_home)
|
||||
|
||||
for model in self._models:
|
||||
if self._models[
|
||||
least_used]['generated'] > self._models[model]['generated']:
|
||||
least_used = model
|
||||
_model_name = name
|
||||
_model_mode = mode
|
||||
|
||||
del self._models[least_used]
|
||||
logging.debug('memory summary:')
|
||||
logging.debug('\n' + torch.cuda.memory_summary())
|
||||
|
||||
logging.info(f'swapping model {least_used} for {model_name}...')
|
||||
yield _model
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
pipe = pipeline_for(
|
||||
model_name, image=image, cache_dir=self.cache_dir)
|
||||
def compute_one(
|
||||
model,
|
||||
request_id: int,
|
||||
method: str,
|
||||
params: dict,
|
||||
inputs: list[bytes] = [],
|
||||
should_cancel = None
|
||||
):
|
||||
total_steps = params['step'] if 'step' in params else 1
|
||||
def inference_step_wakeup(*args, **kwargs):
|
||||
'''This is a callback function that gets invoked every inference step,
|
||||
we need to raise an exception here if we need to cancel work
|
||||
'''
|
||||
step = args[0]
|
||||
# compat with callback_on_step_end
|
||||
if not isinstance(step, int):
|
||||
step = args[1]
|
||||
|
||||
self._models[model_name] = {
|
||||
'pipe': pipe,
|
||||
'generated': 0,
|
||||
'image': image
|
||||
}
|
||||
maybe_update_tui(lambda tui: tui.set_progress(step, done=total_steps))
|
||||
|
||||
logging.info(f'loaded model {model_name}')
|
||||
return pipe
|
||||
should_raise = False
|
||||
if should_cancel:
|
||||
should_raise = trio.from_thread.run(should_cancel, request_id)
|
||||
|
||||
def get_model(self, model_name: str, image: bool) -> DiffusionPipeline:
|
||||
if model_name not in MODELS:
|
||||
raise DGPUComputeError(f'Unknown model {model_name}')
|
||||
if should_raise:
|
||||
logging.warning(f'CANCELLING work at step {step}')
|
||||
raise DGPUInferenceCancelled('network cancel')
|
||||
|
||||
if not self.is_model_loaded(model_name, image):
|
||||
pipe = self.load_model(model_name, image=image)
|
||||
return {}
|
||||
|
||||
else:
|
||||
pipe = self._models[model_name]['pipe']
|
||||
maybe_update_tui(lambda tui: tui.set_status(f'Request #{request_id}'))
|
||||
|
||||
return pipe
|
||||
inference_step_wakeup(0)
|
||||
|
||||
def compute_one(
|
||||
self,
|
||||
request_id: int,
|
||||
method: str,
|
||||
params: dict,
|
||||
input_type: str = 'png',
|
||||
binary: bytes | None = None
|
||||
):
|
||||
def maybe_cancel_work(step, *args, **kwargs):
|
||||
if self._should_cancel:
|
||||
should_raise = trio.from_thread.run(self._should_cancel, request_id)
|
||||
if should_raise:
|
||||
logging.warn(f'cancelling work at step {step}')
|
||||
raise DGPUInferenceCancelled()
|
||||
output_type = 'png'
|
||||
if 'output_type' in params:
|
||||
output_type = params['output_type']
|
||||
|
||||
maybe_cancel_work(0)
|
||||
output = None
|
||||
output_hash = None
|
||||
try:
|
||||
name = params['model']
|
||||
|
||||
output_type = 'png'
|
||||
if 'output_type' in params:
|
||||
output_type = params['output_type']
|
||||
match method:
|
||||
case 'diffuse' | 'txt2img' | 'img2img' | 'inpaint':
|
||||
arguments = prepare_params_for_diffuse(
|
||||
params, method, inputs)
|
||||
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
||||
|
||||
output = None
|
||||
output_hash = None
|
||||
try:
|
||||
match method:
|
||||
case 'diffuse':
|
||||
arguments = prepare_params_for_diffuse(
|
||||
params, input_type, binary=binary)
|
||||
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
||||
model = self.get_model(params['model'], 'image' in extra_params)
|
||||
if 'flux' in name.lower():
|
||||
extra_params['callback_on_step_end'] = inference_step_wakeup
|
||||
|
||||
output = model(
|
||||
prompt,
|
||||
guidance_scale=guidance,
|
||||
num_inference_steps=step,
|
||||
generator=seed,
|
||||
callback=maybe_cancel_work,
|
||||
callback_steps=1,
|
||||
**extra_params
|
||||
).images[0]
|
||||
else:
|
||||
extra_params['callback'] = inference_step_wakeup
|
||||
extra_params['callback_steps'] = 1
|
||||
|
||||
output_binary = b''
|
||||
match output_type:
|
||||
case 'png':
|
||||
if upscaler == 'x4':
|
||||
input_img = output.convert('RGB')
|
||||
up_img, _ = self.upscaler.enhance(
|
||||
convert_from_image_to_cv2(input_img), outscale=4)
|
||||
output = model(
|
||||
prompt,
|
||||
guidance_scale=guidance,
|
||||
num_inference_steps=step,
|
||||
generator=seed,
|
||||
**extra_params
|
||||
).images[0]
|
||||
|
||||
output = convert_from_cv2_to_image(up_img)
|
||||
output_binary = b''
|
||||
match output_type:
|
||||
case 'png':
|
||||
if upscaler == 'x4':
|
||||
input_img = output.convert('RGB')
|
||||
up_img, _ = init_upscaler().enhance(
|
||||
convert_from_image_to_cv2(input_img), outscale=4)
|
||||
|
||||
output_binary = convert_from_img_to_bytes(output)
|
||||
output = convert_from_cv2_to_image(up_img)
|
||||
|
||||
case _:
|
||||
raise DGPUComputeError(f'Unsupported output type: {output_type}')
|
||||
output_binary = convert_from_img_to_bytes(output)
|
||||
|
||||
output_hash = sha256(output_binary).hexdigest()
|
||||
case _:
|
||||
raise DGPUComputeError(f'Unsupported output type: {output_type}')
|
||||
|
||||
case _:
|
||||
raise DGPUComputeError('Unsupported compute method')
|
||||
output_hash = sha256(output_binary).hexdigest()
|
||||
|
||||
except BaseException as e:
|
||||
logging.error(e)
|
||||
raise DGPUComputeError(str(e))
|
||||
case 'upscale':
|
||||
input_img = inputs[0].convert('RGB')
|
||||
up_img, _ = model.enhance(
|
||||
convert_from_image_to_cv2(input_img), outscale=4)
|
||||
|
||||
finally:
|
||||
torch.cuda.empty_cache()
|
||||
output = convert_from_cv2_to_image(up_img)
|
||||
|
||||
return output_hash, output
|
||||
output_binary = convert_from_img_to_bytes(output)
|
||||
output_hash = sha256(output_binary).hexdigest()
|
||||
|
||||
case _:
|
||||
raise DGPUComputeError('Unsupported compute method')
|
||||
|
||||
except BaseException as err:
|
||||
raise DGPUComputeError(str(err)) from err
|
||||
|
||||
maybe_update_tui(lambda tui: tui.set_status(''))
|
||||
|
||||
return output_hash, output
|
||||
|
|
|
@ -1,25 +1,24 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import json
|
||||
import random
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from hashlib import sha256
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from hashlib import sha256
|
||||
|
||||
import trio
|
||||
|
||||
from quart import jsonify
|
||||
from quart_trio import QuartTrio as Quart
|
||||
|
||||
from skynet.constants import MODELS, VERSION
|
||||
|
||||
from skynet.dgpu.errors import *
|
||||
from skynet.dgpu.compute import SkynetMM
|
||||
from skynet.dgpu.network import SkynetGPUConnector
|
||||
from skynet.config import DgpuConfig as Config
|
||||
from skynet.constants import (
|
||||
MODELS,
|
||||
VERSION,
|
||||
)
|
||||
from skynet.dgpu.errors import (
|
||||
DGPUComputeError,
|
||||
)
|
||||
from skynet.dgpu.tui import maybe_update_tui, maybe_update_tui_async
|
||||
from skynet.dgpu.compute import maybe_load_model, compute_one
|
||||
from skynet.dgpu.network import NetConnector
|
||||
|
||||
|
||||
def convert_reward_to_int(reward_str):
|
||||
|
@ -30,197 +29,182 @@ def convert_reward_to_int(reward_str):
|
|||
return int(int_part + decimal_part)
|
||||
|
||||
|
||||
class SkynetDGPUDaemon:
|
||||
async def maybe_update_tui_balance(conn: NetConnector):
|
||||
async def _fn(tui):
|
||||
# update balance
|
||||
balance = await conn.get_worker_balance()
|
||||
tui.set_header_text(new_balance=f'balance: {balance}')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mm: SkynetMM,
|
||||
conn: SkynetGPUConnector,
|
||||
config: dict
|
||||
await maybe_update_tui_async(_fn)
|
||||
|
||||
|
||||
async def maybe_serve_one(
|
||||
config: Config,
|
||||
conn: NetConnector,
|
||||
req: dict,
|
||||
):
|
||||
rid = req['id']
|
||||
logging.info(f'maybe serve request #{rid}')
|
||||
|
||||
# parse request
|
||||
body = json.loads(req['body'])
|
||||
model = body['params']['model']
|
||||
|
||||
# if model not known, ignore.
|
||||
if model not in MODELS:
|
||||
logging.warning(f'unknown model {model}!, skip...')
|
||||
return
|
||||
|
||||
# only handle whitelisted models
|
||||
if (
|
||||
len(config.model_whitelist) > 0
|
||||
and
|
||||
model not in config.model_whitelist
|
||||
):
|
||||
self.mm = mm
|
||||
self.conn = conn
|
||||
self.auto_withdraw = (
|
||||
config['auto_withdraw']
|
||||
if 'auto_withdraw' in config else False
|
||||
)
|
||||
logging.warning('model not whitelisted!, skip...')
|
||||
return
|
||||
|
||||
self.account = config['account']
|
||||
# if blacklist contains model skip
|
||||
if (
|
||||
len(config.model_blacklist) > 0
|
||||
and
|
||||
model in config.model_blacklist
|
||||
):
|
||||
logging.warning('model not blacklisted!, skip...')
|
||||
return
|
||||
|
||||
self.non_compete = set()
|
||||
if 'non_compete' in config:
|
||||
self.non_compete = set(config['non_compete'])
|
||||
results = [res['request_id'] for res in conn._tables['results']]
|
||||
|
||||
self.model_whitelist = set()
|
||||
if 'model_whitelist' in config:
|
||||
self.model_whitelist = set(config['model_whitelist'])
|
||||
# if worker already produced a result for this request
|
||||
if rid in results:
|
||||
logging.info(f'worker already submitted a result for request #{rid}, skip...')
|
||||
return
|
||||
|
||||
self.model_blacklist = set()
|
||||
if 'model_blacklist' in config:
|
||||
self.model_blacklist = set(config['model_blacklist'])
|
||||
statuses = conn._tables['requests'][rid]
|
||||
|
||||
self.backend = 'sync-on-thread'
|
||||
if 'backend' in config:
|
||||
self.backend = config['backend']
|
||||
# skip if workers in non_compete already on it
|
||||
competitors = set((status['worker'] for status in statuses))
|
||||
if bool(config.non_compete & competitors):
|
||||
logging.info('worker in configured non_compete list already working on request, skip...')
|
||||
return
|
||||
|
||||
self._snap = {
|
||||
'queue': [],
|
||||
'requests': {},
|
||||
'my_results': []
|
||||
}
|
||||
# resolve the ipfs hashes into the actual data behind them
|
||||
inputs = []
|
||||
raw_inputs = req['binary_data'].split(',')
|
||||
if raw_inputs:
|
||||
logging.info(f'fetching IPFS inputs: {raw_inputs}')
|
||||
|
||||
self._benchmark = []
|
||||
self._last_benchmark = None
|
||||
self._last_generation_ts = None
|
||||
retry = 3
|
||||
for _input in req['binary_data'].split(','):
|
||||
if _input:
|
||||
for r in range(retry):
|
||||
try:
|
||||
# user `GPUConnector` to IO with
|
||||
# storage layer to seed the compute
|
||||
# task.
|
||||
img = await conn.get_input_data(_input)
|
||||
inputs.append(img)
|
||||
logging.info(f'retrieved {_input}!')
|
||||
break
|
||||
|
||||
def _get_benchmark_speed(self) -> float:
|
||||
if not self._last_benchmark:
|
||||
return 0
|
||||
except BaseException:
|
||||
logging.exception(
|
||||
f'IPFS fetch input error !?! retries left {retry - r - 1}\n'
|
||||
)
|
||||
|
||||
start = self._last_benchmark[0]
|
||||
end = self._last_benchmark[-1]
|
||||
# compute unique request hash used on submit
|
||||
hash_str = (
|
||||
str(req['nonce'])
|
||||
+
|
||||
req['body']
|
||||
+
|
||||
req['binary_data']
|
||||
)
|
||||
logging.debug(f'hashing: {hash_str}')
|
||||
request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
|
||||
logging.info(f'calculated request hash: {request_hash}')
|
||||
|
||||
elapsed = end - start
|
||||
its = len(self._last_benchmark)
|
||||
speed = its / elapsed
|
||||
params = body['params']
|
||||
total_step = params['step'] if 'step' in params else 1
|
||||
model = body['params']['model']
|
||||
mode = body['method']
|
||||
|
||||
logging.info(f'{elapsed} s total its: {its}, at {speed} it/s ')
|
||||
# TODO: validate request
|
||||
|
||||
return speed
|
||||
resp = await conn.begin_work(rid)
|
||||
if not resp or 'code' in resp:
|
||||
logging.info('begin_work error, probably being worked on already... skip.')
|
||||
return
|
||||
|
||||
async def should_cancel_work(self, request_id: int):
|
||||
self._benchmark.append(time.time())
|
||||
competitors = set([
|
||||
status['worker']
|
||||
for status in self._snap['requests'][request_id]
|
||||
if status['worker'] != self.account
|
||||
])
|
||||
return bool(self.non_compete & competitors)
|
||||
with maybe_load_model(model, mode) as model:
|
||||
try:
|
||||
maybe_update_tui(lambda tui: tui.set_progress(0, done=total_step))
|
||||
|
||||
output_type = 'png'
|
||||
if 'output_type' in body['params']:
|
||||
output_type = body['params']['output_type']
|
||||
|
||||
output = None
|
||||
output_hash = None
|
||||
match config.backend:
|
||||
case 'sync-on-thread':
|
||||
output_hash, output = await trio.to_thread.run_sync(
|
||||
partial(
|
||||
compute_one,
|
||||
model,
|
||||
rid,
|
||||
mode, params,
|
||||
inputs=inputs,
|
||||
should_cancel=conn.should_cancel_work,
|
||||
)
|
||||
)
|
||||
|
||||
case _:
|
||||
raise DGPUComputeError(
|
||||
f'Unsupported backend {config.backend}'
|
||||
)
|
||||
|
||||
maybe_update_tui(lambda tui: tui.set_progress(total_step))
|
||||
|
||||
ipfs_hash = await conn.publish_on_ipfs(output, typ=output_type)
|
||||
|
||||
await conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
|
||||
|
||||
await maybe_update_tui_balance(conn)
|
||||
|
||||
|
||||
async def snap_updater_task(self):
|
||||
except BaseException as err:
|
||||
if 'network cancel' not in str(err):
|
||||
logging.exception('Failed to serve model request !?\n')
|
||||
|
||||
if rid in conn._tables['requests']:
|
||||
await conn.cancel_work(rid, 'reason not provided')
|
||||
|
||||
|
||||
async def dgpu_serve_forever(config: Config, conn: NetConnector):
|
||||
await maybe_update_tui_balance(conn)
|
||||
|
||||
last_poll_idx = -1
|
||||
try:
|
||||
while True:
|
||||
self._snap = await self.conn.get_full_queue_snapshot()
|
||||
await trio.sleep(1)
|
||||
await conn.wait_data_update()
|
||||
if conn.poll_index == last_poll_idx:
|
||||
await trio.sleep(config.poll_time)
|
||||
continue
|
||||
|
||||
async def generate_api(self):
|
||||
app = Quart(__name__)
|
||||
last_poll_idx = conn.poll_index
|
||||
|
||||
@app.route('/')
|
||||
async def health():
|
||||
return jsonify(
|
||||
account=self.account,
|
||||
version=VERSION,
|
||||
last_generation_ts=self._last_generation_ts,
|
||||
last_generation_speed=self._get_benchmark_speed()
|
||||
queue = conn._tables['queue']
|
||||
|
||||
random.shuffle(queue)
|
||||
queue = sorted(
|
||||
queue,
|
||||
key=lambda req: convert_reward_to_int(req['reward']),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
return app
|
||||
if len(queue) > 0:
|
||||
await maybe_serve_one(config, conn, queue[0])
|
||||
|
||||
async def serve_forever(self):
|
||||
try:
|
||||
while True:
|
||||
if self.auto_withdraw:
|
||||
await self.conn.maybe_withdraw_all()
|
||||
|
||||
queue = self._snap['queue']
|
||||
|
||||
random.shuffle(queue)
|
||||
queue = sorted(
|
||||
queue,
|
||||
key=lambda req: convert_reward_to_int(req['reward']),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
for req in queue:
|
||||
rid = req['id']
|
||||
|
||||
# parse request
|
||||
body = json.loads(req['body'])
|
||||
model = body['params']['model']
|
||||
|
||||
# if model not known
|
||||
if model not in MODELS:
|
||||
logging.warning(f'Unknown model {model}')
|
||||
continue
|
||||
|
||||
# if whitelist enabled and model not in it continue
|
||||
if (len(self.model_whitelist) > 0 and
|
||||
not model in self.model_whitelist):
|
||||
continue
|
||||
|
||||
# if blacklist contains model skip
|
||||
if model in self.model_blacklist:
|
||||
continue
|
||||
|
||||
my_results = [res['id'] for res in self._snap['my_results']]
|
||||
if rid not in my_results and rid in self._snap['requests']:
|
||||
statuses = self._snap['requests'][rid]
|
||||
|
||||
if len(statuses) == 0:
|
||||
binary, input_type = await self.conn.get_input_data(req['binary_data'])
|
||||
|
||||
hash_str = (
|
||||
str(req['nonce'])
|
||||
+
|
||||
req['body']
|
||||
+
|
||||
req['binary_data']
|
||||
)
|
||||
logging.info(f'hashing: {hash_str}')
|
||||
request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
|
||||
|
||||
# TODO: validate request
|
||||
|
||||
# perform work
|
||||
logging.info(f'working on {body}')
|
||||
|
||||
resp = await self.conn.begin_work(rid)
|
||||
if 'code' in resp:
|
||||
logging.info(f'probably being worked on already... skip.')
|
||||
|
||||
else:
|
||||
try:
|
||||
output_type = 'png'
|
||||
if 'output_type' in body['params']:
|
||||
output_type = body['params']['output_type']
|
||||
|
||||
output = None
|
||||
output_hash = None
|
||||
match self.backend:
|
||||
case 'sync-on-thread':
|
||||
self.mm._should_cancel = self.should_cancel_work
|
||||
output_hash, output = await trio.to_thread.run_sync(
|
||||
partial(
|
||||
self.mm.compute_one,
|
||||
rid,
|
||||
body['method'], body['params'],
|
||||
input_type=input_type,
|
||||
binary=binary
|
||||
)
|
||||
)
|
||||
|
||||
case _:
|
||||
raise DGPUComputeError(f'Unsupported backend {self.backend}')
|
||||
self._last_generation_ts = datetime.now().isoformat()
|
||||
self._last_benchmark = self._benchmark
|
||||
self._benchmark = []
|
||||
|
||||
ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
|
||||
|
||||
await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
|
||||
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
await self.conn.cancel_work(rid, str(e))
|
||||
|
||||
finally:
|
||||
break
|
||||
|
||||
else:
|
||||
logging.info(f'request {rid} already beign worked on, skip...')
|
||||
|
||||
await trio.sleep(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
...
|
||||
except KeyboardInterrupt:
|
||||
...
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
|
||||
class DGPUComputeError(BaseException):
|
||||
...
|
||||
|
|
|
@ -1,100 +1,123 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
from functools import partial
|
||||
|
||||
import asks
|
||||
import trio
|
||||
import leap
|
||||
import anyio
|
||||
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
|
||||
import httpx
|
||||
import outcome
|
||||
from PIL import Image
|
||||
from leap.cleos import CLEOS
|
||||
from leap.sugar import Checksum256, Name, asset_from_str
|
||||
from skynet.constants import DEFAULT_IPFS_DOMAIN
|
||||
from leap.protocol import Asset
|
||||
from skynet.dgpu.tui import maybe_update_tui
|
||||
from skynet.config import DgpuConfig as Config
|
||||
from skynet.constants import (
|
||||
DEFAULT_IPFS_DOMAIN,
|
||||
GPU_CONTRACT_ABI,
|
||||
)
|
||||
|
||||
from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
|
||||
from skynet.dgpu.errors import DGPUComputeError
|
||||
from skynet.ipfs import (
|
||||
AsyncIPFSHTTP,
|
||||
get_ipfs_file,
|
||||
)
|
||||
|
||||
|
||||
REQUEST_UPDATE_TIME = 3
|
||||
REQUEST_UPDATE_TIME: int = 3
|
||||
|
||||
|
||||
async def failable(fn: partial, ret_fail=None):
|
||||
try:
|
||||
return await fn()
|
||||
o = await outcome.acapture(fn)
|
||||
match o:
|
||||
case outcome.Error(error=(
|
||||
OSError() |
|
||||
json.JSONDecodeError() |
|
||||
anyio.BrokenResourceError() |
|
||||
httpx.ConnectError() |
|
||||
httpx.ConnectTimeout() |
|
||||
httpx.ReadError() |
|
||||
httpx.ReadTimeout() |
|
||||
leap.errors.TransactionPushError()
|
||||
)):
|
||||
return ret_fail
|
||||
|
||||
except (
|
||||
OSError,
|
||||
json.JSONDecodeError,
|
||||
asks.errors.RequestTimeout,
|
||||
asks.errors.BadHttpResponse,
|
||||
anyio.BrokenResourceError
|
||||
):
|
||||
return ret_fail
|
||||
case _:
|
||||
return o.unwrap()
|
||||
|
||||
|
||||
class SkynetGPUConnector:
|
||||
class NetConnector:
|
||||
'''
|
||||
An API for connecting to and conducting various "high level"
|
||||
network-service operations in the skynet.
|
||||
|
||||
def __init__(self, config: dict):
|
||||
self.account = Name(config['account'])
|
||||
self.permission = config['permission']
|
||||
self.key = config['key']
|
||||
- skynet user account creds
|
||||
- hyperion API
|
||||
- IPFs client
|
||||
- CLEOS client
|
||||
|
||||
self.node_url = config['node_url']
|
||||
self.hyperion_url = config['hyperion_url']
|
||||
'''
|
||||
def __init__(self, config: Config):
|
||||
self.config = config
|
||||
self.cleos = CLEOS(endpoint=config.node_url)
|
||||
self.cleos.load_abi('gpu.scd', GPU_CONTRACT_ABI)
|
||||
|
||||
self.cleos = CLEOS(
|
||||
None, None, self.node_url, remote=self.node_url)
|
||||
self.ipfs_client = AsyncIPFSHTTP(config.ipfs_url)
|
||||
|
||||
self.ipfs_gateway_url = None
|
||||
if 'ipfs_gateway_url' in config:
|
||||
self.ipfs_gateway_url = config['ipfs_gateway_url']
|
||||
self.ipfs_url = config['ipfs_url']
|
||||
# poll_index is used to detect stale data
|
||||
self.poll_index = 0
|
||||
self._tables = {
|
||||
'queue': [],
|
||||
'requests': {},
|
||||
'results': []
|
||||
}
|
||||
self._data_event = trio.Event()
|
||||
|
||||
self.ipfs_client = AsyncIPFSHTTP(self.ipfs_url)
|
||||
maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.config.account))
|
||||
|
||||
self.ipfs_domain = DEFAULT_IPFS_DOMAIN
|
||||
if 'ipfs_domain' in config:
|
||||
self.ipfs_domain = config['ipfs_domain']
|
||||
|
||||
self._wip_requests = {}
|
||||
|
||||
# blockchain helpers
|
||||
|
||||
async def get_work_requests_last_hour(self):
|
||||
logging.info('get_work_requests_last_hour')
|
||||
return await failable(
|
||||
rows = await failable(
|
||||
partial(
|
||||
self.cleos.aget_table,
|
||||
'telos.gpu', 'telos.gpu', 'queue',
|
||||
'gpu.scd', 'gpu.scd', 'queue',
|
||||
index_position=2,
|
||||
key_type='i64',
|
||||
lower_bound=int(time.time()) - 3600
|
||||
), ret_fail=[])
|
||||
|
||||
logging.info(f'found {len(rows)} requests on queue')
|
||||
return rows
|
||||
|
||||
async def get_status_by_request_id(self, request_id: int):
|
||||
logging.info('get_status_by_request_id')
|
||||
return await failable(
|
||||
rows = await failable(
|
||||
partial(
|
||||
self.cleos.aget_table,
|
||||
'telos.gpu', request_id, 'status'), ret_fail=[])
|
||||
'gpu.scd', request_id, 'status'), ret_fail=[])
|
||||
|
||||
logging.info(f'found status for workers: {[r["worker"] for r in rows]}')
|
||||
return rows
|
||||
|
||||
async def get_global_config(self):
|
||||
logging.info('get_global_config')
|
||||
rows = await failable(
|
||||
partial(
|
||||
self.cleos.aget_table,
|
||||
'telos.gpu', 'telos.gpu', 'config'))
|
||||
'gpu.scd', 'gpu.scd', 'config'))
|
||||
|
||||
if rows:
|
||||
return rows[0]
|
||||
cfg = rows[0]
|
||||
logging.info(f'config found: {cfg}')
|
||||
return cfg
|
||||
else:
|
||||
logging.error('global config not found, is the contract initialized?')
|
||||
return None
|
||||
|
||||
async def get_worker_balance(self):
|
||||
|
@ -102,33 +125,29 @@ class SkynetGPUConnector:
|
|||
rows = await failable(
|
||||
partial(
|
||||
self.cleos.aget_table,
|
||||
'telos.gpu', 'telos.gpu', 'users',
|
||||
'gpu.scd', 'gpu.scd', 'users',
|
||||
index_position=1,
|
||||
key_type='name',
|
||||
lower_bound=self.account,
|
||||
upper_bound=self.account
|
||||
lower_bound=self.config.account,
|
||||
upper_bound=self.config.account
|
||||
))
|
||||
|
||||
if rows:
|
||||
return rows[0]['balance']
|
||||
b = rows[0]['balance']
|
||||
logging.info(f'balance: {b}')
|
||||
return b
|
||||
else:
|
||||
logging.info('no balance info found')
|
||||
return None
|
||||
|
||||
async def get_competitors_for_req(self, request_id: int) -> set:
|
||||
competitors = [
|
||||
status['worker']
|
||||
for status in
|
||||
(await self.get_status_by_request_id(request_id))
|
||||
if status['worker'] != self.account
|
||||
]
|
||||
logging.info(f'competitors: {competitors}')
|
||||
return set(competitors)
|
||||
|
||||
|
||||
async def get_full_queue_snapshot(self):
|
||||
'''
|
||||
Get a "snapshot" of current contract table state
|
||||
|
||||
'''
|
||||
snap = {
|
||||
'requests': {},
|
||||
'my_results': []
|
||||
'results': []
|
||||
}
|
||||
|
||||
snap['queue'] = await self.get_work_requests_last_hour()
|
||||
|
@ -137,44 +156,86 @@ class SkynetGPUConnector:
|
|||
d[key] = await fn(*args, **kwargs)
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
n.start_soon(_run_and_save, snap, 'my_results', self.find_my_results)
|
||||
n.start_soon(_run_and_save, snap, 'results', self.find_results)
|
||||
for req in snap['queue']:
|
||||
n.start_soon(
|
||||
_run_and_save, snap['requests'], req['id'], self.get_status_by_request_id, req['id'])
|
||||
|
||||
|
||||
maybe_update_tui(lambda tui: tui.network_update(snap))
|
||||
|
||||
return snap
|
||||
|
||||
async def wait_data_update(self):
|
||||
await self._data_event.wait()
|
||||
|
||||
async def iter_poll_update(self, poll_time: float):
|
||||
'''
|
||||
Long running task, polls gpu contract tables latest table rows,
|
||||
awakes any self._data_event waiters
|
||||
|
||||
'''
|
||||
while True:
|
||||
start_time = time.time()
|
||||
self._tables = await self.get_full_queue_snapshot()
|
||||
elapsed = time.time() - start_time
|
||||
self._data_event.set()
|
||||
await trio.sleep(max(poll_time - elapsed, 0.1))
|
||||
self._data_event = trio.Event()
|
||||
self.poll_index += 1
|
||||
|
||||
async def should_cancel_work(self, request_id: int) -> bool:
|
||||
logging.info('should cancel work?')
|
||||
if request_id not in self._tables['requests']:
|
||||
logging.info(f'request #{request_id} no longer in queue, likely its been filled by another worker, cancelling work...')
|
||||
return True
|
||||
|
||||
competitors = set([
|
||||
status['worker']
|
||||
for status in self._tables['requests'][request_id]
|
||||
if status['worker'] != self.config.account
|
||||
])
|
||||
logging.info(f'competitors: {competitors}')
|
||||
should_cancel = bool(self.config.non_compete & competitors)
|
||||
logging.info(f'cancel: {should_cancel}')
|
||||
return should_cancel
|
||||
|
||||
async def begin_work(self, request_id: int):
|
||||
logging.info('begin_work')
|
||||
'''
|
||||
Publish to the bc that the worker is beginning a model-computation
|
||||
step.
|
||||
|
||||
'''
|
||||
logging.info(f'begin_work on #{request_id}')
|
||||
return await failable(
|
||||
partial(
|
||||
self.cleos.a_push_action,
|
||||
'telos.gpu',
|
||||
'gpu.scd',
|
||||
'workbegin',
|
||||
{
|
||||
'worker': self.account,
|
||||
list({
|
||||
'worker': self.config.account,
|
||||
'request_id': request_id,
|
||||
'max_workers': 2
|
||||
},
|
||||
self.account, self.key,
|
||||
permission=self.permission
|
||||
}.values()),
|
||||
self.config.account, self.config.key,
|
||||
permission=self.config.permission
|
||||
)
|
||||
)
|
||||
|
||||
async def cancel_work(self, request_id: int, reason: str):
|
||||
logging.info('cancel_work')
|
||||
logging.info(f'cancel_work on #{request_id}')
|
||||
return await failable(
|
||||
partial(
|
||||
self.cleos.a_push_action,
|
||||
'telos.gpu',
|
||||
'gpu.scd',
|
||||
'workcancel',
|
||||
{
|
||||
'worker': self.account,
|
||||
list({
|
||||
'worker': self.config.account,
|
||||
'request_id': request_id,
|
||||
'reason': reason
|
||||
},
|
||||
self.account, self.key,
|
||||
permission=self.permission
|
||||
}.values()),
|
||||
self.config.account, self.config.key,
|
||||
permission=self.config.permission
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -189,29 +250,30 @@ class SkynetGPUConnector:
|
|||
await failable(
|
||||
partial(
|
||||
self.cleos.a_push_action,
|
||||
'telos.gpu',
|
||||
'gpu.scd',
|
||||
'withdraw',
|
||||
{
|
||||
'user': self.account,
|
||||
'quantity': asset_from_str(balance)
|
||||
},
|
||||
self.account, self.key,
|
||||
permission=self.permission
|
||||
list({
|
||||
'user': self.config.account,
|
||||
'quantity': Asset.from_str(balance)
|
||||
}.values()),
|
||||
self.config.account, self.config.key,
|
||||
permission=self.config.permission
|
||||
)
|
||||
)
|
||||
|
||||
async def find_my_results(self):
|
||||
logging.info('find_my_results')
|
||||
return await failable(
|
||||
async def find_results(self):
|
||||
logging.info('find_results')
|
||||
rows = await failable(
|
||||
partial(
|
||||
self.cleos.aget_table,
|
||||
'telos.gpu', 'telos.gpu', 'results',
|
||||
'gpu.scd', 'gpu.scd', 'results',
|
||||
index_position=4,
|
||||
key_type='name',
|
||||
lower_bound=self.account,
|
||||
upper_bound=self.account
|
||||
lower_bound=self.config.account,
|
||||
upper_bound=self.config.account
|
||||
)
|
||||
)
|
||||
return rows
|
||||
|
||||
async def submit_work(
|
||||
self,
|
||||
|
@ -220,21 +282,21 @@ class SkynetGPUConnector:
|
|||
result_hash: str,
|
||||
ipfs_hash: str
|
||||
):
|
||||
logging.info('submit_work')
|
||||
logging.info(f'submit_work #{request_id}')
|
||||
return await failable(
|
||||
partial(
|
||||
self.cleos.a_push_action,
|
||||
'telos.gpu',
|
||||
'gpu.scd',
|
||||
'submit',
|
||||
{
|
||||
'worker': self.account,
|
||||
list({
|
||||
'worker': self.config.account,
|
||||
'request_id': request_id,
|
||||
'request_hash': Checksum256(request_hash),
|
||||
'result_hash': Checksum256(result_hash),
|
||||
'request_hash': request_hash,
|
||||
'result_hash': result_hash,
|
||||
'ipfs_hash': ipfs_hash
|
||||
},
|
||||
self.account, self.key,
|
||||
permission=self.permission
|
||||
}.values()),
|
||||
self.config.account, self.config.key,
|
||||
permission=self.config.permission
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -253,60 +315,32 @@ class SkynetGPUConnector:
|
|||
case _:
|
||||
raise ValueError(f'Unsupported output type: {typ}')
|
||||
|
||||
if self.ipfs_gateway_url:
|
||||
# check peer connections, reconnect to skynet gateway if not
|
||||
gateway_id = Path(self.ipfs_gateway_url).name
|
||||
peers = await self.ipfs_client.peers()
|
||||
if gateway_id not in [p['Peer'] for p in peers]:
|
||||
await self.ipfs_client.connect(self.ipfs_gateway_url)
|
||||
|
||||
file_info = await self.ipfs_client.add(Path(target_file))
|
||||
file_cid = file_info['Hash']
|
||||
logging.info(f'added file to ipfs, CID: {file_cid}')
|
||||
|
||||
await self.ipfs_client.pin(file_cid)
|
||||
logging.info(f'pinned {file_cid}')
|
||||
|
||||
return file_cid
|
||||
|
||||
async def get_input_data(self, ipfs_hash: str) -> tuple[bytes, str]:
|
||||
input_type = 'none'
|
||||
async def get_input_data(self, ipfs_hash: str) -> Image:
|
||||
'''
|
||||
Retrieve an input (image) from the IPFs layer.
|
||||
|
||||
if ipfs_hash == '':
|
||||
return b'', input_type
|
||||
Normally used to retreive seed (visual) content previously
|
||||
generated/validated by the network to be fed to some
|
||||
consuming AI model.
|
||||
|
||||
results = {}
|
||||
ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}'
|
||||
ipfs_link_legacy = ipfs_link + '/image.png'
|
||||
'''
|
||||
link = f'https://{self.config.ipfs_domain}/ipfs/{ipfs_hash}'
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
async def get_and_set_results(link: str):
|
||||
res = await get_ipfs_file(link, timeout=1)
|
||||
logging.info(f'got response from {link}')
|
||||
if not res or res.status_code != 200:
|
||||
logging.warning(f'couldn\'t get ipfs binary data at {link}!')
|
||||
res = await get_ipfs_file(link, timeout=1)
|
||||
if not res or res.status_code != 200:
|
||||
logging.warning(f'couldn\'t get ipfs binary data at {link}!')
|
||||
|
||||
else:
|
||||
try:
|
||||
# attempt to decode as image
|
||||
results[link] = Image.open(io.BytesIO(res.raw))
|
||||
input_type = 'png'
|
||||
n.cancel_scope.cancel()
|
||||
# attempt to decode as image
|
||||
input_data = Image.open(io.BytesIO(res.read()))
|
||||
logging.info('decoded as image successfully')
|
||||
|
||||
except UnidentifiedImageError:
|
||||
logging.warning(f'couldn\'t get ipfs binary data at {link}!')
|
||||
|
||||
n.start_soon(
|
||||
get_and_set_results, ipfs_link)
|
||||
n.start_soon(
|
||||
get_and_set_results, ipfs_link_legacy)
|
||||
|
||||
input_data = None
|
||||
if ipfs_link_legacy in results:
|
||||
input_data = results[ipfs_link_legacy]
|
||||
|
||||
if ipfs_link in results:
|
||||
input_data = results[ipfs_link]
|
||||
|
||||
if input_data == None:
|
||||
raise DGPUComputeError('Couldn\'t gather input data from ipfs')
|
||||
|
||||
return input_data, input_type
|
||||
return input_data
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
import torch
|
||||
|
||||
from diffusers import (
|
||||
DiffusionPipeline,
|
||||
FluxPipeline,
|
||||
FluxTransformer2DModel
|
||||
)
|
||||
from transformers import T5EncoderModel, BitsAndBytesConfig
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
__model = {
|
||||
'name': 'black-forest-labs/FLUX.1-schnell'
|
||||
}
|
||||
|
||||
def pipeline_for(
|
||||
model: str,
|
||||
mode: str,
|
||||
mem_fraction: float = 1.0,
|
||||
cache_dir: str | None = None
|
||||
) -> DiffusionPipeline:
|
||||
qonfig = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
params = {
|
||||
'torch_dtype': torch.bfloat16,
|
||||
'cache_dir': cache_dir,
|
||||
'device_map': 'balanced',
|
||||
'max_memory': {'cpu': '10GiB', 0: '11GiB'}
|
||||
# 'max_memory': {0: '11GiB'}
|
||||
}
|
||||
|
||||
text_encoder = T5EncoderModel.from_pretrained(
|
||||
'black-forest-labs/FLUX.1-schnell',
|
||||
subfolder="text_encoder_2",
|
||||
torch_dtype=torch.bfloat16,
|
||||
quantization_config=qonfig
|
||||
)
|
||||
params['text_encoder_2'] = text_encoder
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
model, **params)
|
||||
|
||||
pipe.vae.enable_tiling()
|
||||
pipe.vae.enable_slicing()
|
||||
|
||||
return pipe
|
|
@ -0,0 +1,54 @@
|
|||
import torch
|
||||
|
||||
from diffusers import (
|
||||
DiffusionPipeline,
|
||||
FluxFillPipeline,
|
||||
FluxTransformer2DModel
|
||||
)
|
||||
from transformers import T5EncoderModel, BitsAndBytesConfig
|
||||
|
||||
__model = {
|
||||
'name': 'black-forest-labs/FLUX.1-Fill-dev'
|
||||
}
|
||||
|
||||
def pipeline_for(
|
||||
model: str,
|
||||
mode: str,
|
||||
mem_fraction: float = 1.0,
|
||||
cache_dir: str | None = None
|
||||
) -> DiffusionPipeline:
|
||||
qonfig = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
params = {
|
||||
'torch_dtype': torch.bfloat16,
|
||||
'cache_dir': cache_dir,
|
||||
'device_map': 'balanced',
|
||||
'max_memory': {'cpu': '10GiB', 0: '11GiB'}
|
||||
# 'max_memory': {0: '11GiB'}
|
||||
}
|
||||
|
||||
text_encoder = T5EncoderModel.from_pretrained(
|
||||
'sayakpaul/FLUX.1-Fill-dev-nf4',
|
||||
subfolder="text_encoder_2",
|
||||
torch_dtype=torch.bfloat16,
|
||||
quantization_config=qonfig
|
||||
)
|
||||
params['text_encoder_2'] = text_encoder
|
||||
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
'sayakpaul/FLUX.1-Fill-dev-nf4',
|
||||
subfolder="transformer",
|
||||
torch_dtype=torch.bfloat16,
|
||||
quantization_config=qonfig
|
||||
)
|
||||
params['transformer'] = transformer
|
||||
|
||||
pipe = FluxFillPipeline.from_pretrained(
|
||||
model, **params)
|
||||
|
||||
pipe.vae.enable_tiling()
|
||||
pipe.vae.enable_slicing()
|
||||
|
||||
return pipe
|
|
@ -0,0 +1,42 @@
|
|||
import time
|
||||
|
||||
from PIL import Image
|
||||
|
||||
import msgspec
|
||||
|
||||
|
||||
__model = {
|
||||
'name': 'skygpu/txt2img-mocker'
|
||||
}
|
||||
|
||||
class MockPipelineResult(msgspec.Struct):
|
||||
images: list[Image]
|
||||
|
||||
class MockPipeline:
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
*args,
|
||||
num_inference_steps: int = 3,
|
||||
callback=None,
|
||||
mock_step_time: float = 0.1,
|
||||
**kwargs
|
||||
):
|
||||
for i in range(num_inference_steps):
|
||||
time.sleep(mock_step_time)
|
||||
if callback:
|
||||
callback(i+1)
|
||||
|
||||
img = Image.new('RGB', (1, 1), color='green')
|
||||
|
||||
return MockPipelineResult(images=[img])
|
||||
|
||||
|
||||
def pipeline_for(
|
||||
model: str,
|
||||
mode: str,
|
||||
mem_fraction: float = 1.0,
|
||||
cache_dir: str | None = None
|
||||
):
|
||||
return MockPipeline()
|
|
@ -0,0 +1,211 @@
|
|||
import json
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
import trio
|
||||
import urwid
|
||||
|
||||
from skynet.config import DgpuConfig as Config
|
||||
|
||||
|
||||
class WorkerMonitor:
|
||||
def __init__(self):
|
||||
self.requests = []
|
||||
self.header_info = {}
|
||||
|
||||
self.palette = [
|
||||
('headerbar', 'white', 'dark blue'),
|
||||
('request_row', 'white', 'dark gray'),
|
||||
('worker_row', 'light gray', 'black'),
|
||||
('progress_normal', 'black', 'light gray'),
|
||||
('progress_complete', 'black', 'dark green'),
|
||||
('body', 'white', 'black'),
|
||||
]
|
||||
|
||||
# --- Top bar (header) ---
|
||||
worker_name = self.header_info.get('left', "unknown")
|
||||
balance = self.header_info.get('right', "balance: unknown")
|
||||
|
||||
self.worker_name_widget = urwid.Text(worker_name)
|
||||
self.balance_widget = urwid.Text(balance, align='right')
|
||||
|
||||
header = urwid.Columns([self.worker_name_widget, self.balance_widget])
|
||||
header_attr = urwid.AttrMap(header, 'headerbar')
|
||||
|
||||
# --- Body (List of requests) ---
|
||||
self.body_listbox = self._create_listbox_body(self.requests)
|
||||
|
||||
# --- Bottom bar (progress) ---
|
||||
self.status_text = urwid.Text("Request: none", align='left')
|
||||
self.progress_bar = urwid.ProgressBar(
|
||||
'progress_normal',
|
||||
'progress_complete',
|
||||
current=0,
|
||||
done=100
|
||||
)
|
||||
|
||||
footer_cols = urwid.Columns([
|
||||
('fixed', 20, self.status_text),
|
||||
self.progress_bar,
|
||||
])
|
||||
|
||||
# Build the main frame
|
||||
frame = urwid.Frame(
|
||||
self.body_listbox,
|
||||
header=header_attr,
|
||||
footer=footer_cols
|
||||
)
|
||||
|
||||
# Set up the main loop with Trio
|
||||
self.event_loop = urwid.TrioEventLoop()
|
||||
self.main_loop = urwid.MainLoop(
|
||||
frame,
|
||||
palette=self.palette,
|
||||
event_loop=self.event_loop,
|
||||
unhandled_input=self._exit_on_q
|
||||
)
|
||||
|
||||
def _create_listbox_body(self, requests):
|
||||
"""
|
||||
Build a ListBox (vertical list) of requests & workers using SimpleFocusListWalker.
|
||||
"""
|
||||
widgets = self._build_request_widgets(requests)
|
||||
walker = urwid.SimpleFocusListWalker(widgets)
|
||||
return urwid.ListBox(walker)
|
||||
|
||||
def _build_request_widgets(self, requests):
|
||||
"""
|
||||
Build a list of Urwid widgets (one row per request + per worker).
|
||||
"""
|
||||
row_widgets = []
|
||||
|
||||
for req in requests:
|
||||
# Build a columns widget for the request row
|
||||
prompt = req['prompt'] if 'prompt' in req else 'UPSCALE'
|
||||
columns = urwid.Columns([
|
||||
('fixed', 5, urwid.Text(f"#{req['id']}")), # e.g. "#12"
|
||||
('weight', 3, urwid.Text(req['model'])),
|
||||
('weight', 3, urwid.Text(prompt)),
|
||||
('fixed', 13, urwid.Text(req['user'])),
|
||||
('fixed', 13, urwid.Text(req['reward'])),
|
||||
], dividechars=1)
|
||||
|
||||
# Wrap the columns with an attribute map for coloring
|
||||
request_row = urwid.AttrMap(columns, 'request_row')
|
||||
row_widgets.append(request_row)
|
||||
|
||||
# Then add each worker in its own line below
|
||||
for w in req["workers"]:
|
||||
worker_line = urwid.Text(f" {w}")
|
||||
worker_row = urwid.AttrMap(worker_line, 'worker_row')
|
||||
row_widgets.append(worker_row)
|
||||
|
||||
# Optional blank line after each request
|
||||
row_widgets.append(urwid.Text(""))
|
||||
|
||||
return row_widgets
|
||||
|
||||
def _exit_on_q(self, key):
|
||||
"""Exit the TUI on 'q' or 'Q'."""
|
||||
if key in ('q', 'Q'):
|
||||
raise urwid.ExitMainLoop()
|
||||
|
||||
async def run(self):
|
||||
"""
|
||||
Run the TUI in an async context (Trio).
|
||||
This method blocks until the user quits (pressing q/Q).
|
||||
"""
|
||||
with self.main_loop.start():
|
||||
await self.event_loop.run_async()
|
||||
|
||||
raise urwid.ExitMainLoop()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Public Methods to Update Various Parts of the UI
|
||||
# -------------------------------------------------------------------------
|
||||
def set_status(self, status: str):
|
||||
self.status_text.set_text(status)
|
||||
|
||||
def set_progress(self, current, done=None):
|
||||
"""
|
||||
Update the bottom progress bar.
|
||||
- `current`: new current progress value (int).
|
||||
- `done`: max progress value (int). If None, we don’t change it.
|
||||
"""
|
||||
if done is not None:
|
||||
self.progress_bar.done = done
|
||||
|
||||
self.progress_bar.current = current
|
||||
|
||||
pct = 0
|
||||
if self.progress_bar.done != 0:
|
||||
pct = int((self.progress_bar.current / self.progress_bar.done) * 100)
|
||||
|
||||
def update_requests(self, new_requests):
|
||||
"""
|
||||
Replace the data in the existing ListBox with new request widgets.
|
||||
"""
|
||||
new_widgets = self._build_request_widgets(new_requests)
|
||||
self.body_listbox.body[:] = new_widgets # replace content of the list walker
|
||||
|
||||
def set_header_text(self, new_worker_name=None, new_balance=None):
|
||||
"""
|
||||
Update the text in the header bar for worker name and/or balance.
|
||||
"""
|
||||
if new_worker_name is not None:
|
||||
self.worker_name_widget.set_text(new_worker_name)
|
||||
if new_balance is not None:
|
||||
self.balance_widget.set_text(new_balance)
|
||||
|
||||
def network_update(self, snapshot: dict):
|
||||
queue = [
|
||||
{
|
||||
**r,
|
||||
**(json.loads(r['body'])['params']),
|
||||
'workers': [s['worker'] for s in snapshot['requests'][r['id']]]
|
||||
}
|
||||
for r in snapshot['queue']
|
||||
]
|
||||
self.update_requests(queue)
|
||||
|
||||
|
||||
def setup_logging_for_tui(config: Config):
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
level = getattr(logging, config.log_level.upper(), logging.WARNING)
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(level)
|
||||
|
||||
fh = logging.FileHandler(config.log_file)
|
||||
fh.setLevel(level)
|
||||
|
||||
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||
fh.setFormatter(formatter)
|
||||
|
||||
logger.addHandler(fh)
|
||||
|
||||
for handler in logger.handlers:
|
||||
if isinstance(handler, logging.StreamHandler):
|
||||
logger.removeHandler(handler)
|
||||
|
||||
|
||||
_tui: WorkerMonitor | None = None
|
||||
def init_tui(config: Config):
|
||||
global _tui
|
||||
assert not _tui
|
||||
setup_logging_for_tui(config)
|
||||
_tui = WorkerMonitor()
|
||||
return _tui
|
||||
|
||||
|
||||
def maybe_update_tui(fn):
|
||||
global _tui
|
||||
if _tui:
|
||||
fn(_tui)
|
||||
|
||||
|
||||
async def maybe_update_tui_async(fn):
|
||||
global _tui
|
||||
if _tui:
|
||||
await fn(_tui)
|
|
@ -0,0 +1,323 @@
|
|||
import io
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import random
|
||||
import logging
|
||||
import importlib
|
||||
|
||||
from typing import Optional
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
import diffusers
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
from diffusers import (
|
||||
DiffusionPipeline,
|
||||
AutoPipelineForText2Image,
|
||||
AutoPipelineForImage2Image,
|
||||
AutoPipelineForInpainting,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from huggingface_hub import login, hf_hub_download
|
||||
|
||||
from skynet.config import load_skynet_toml
|
||||
from skynet.constants import MODELS
|
||||
|
||||
# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks
|
||||
# basicsr; see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13985
|
||||
try:
|
||||
import torchvision.transforms.functional_tensor # noqa: F401
|
||||
except ImportError:
|
||||
try:
|
||||
import torchvision.transforms.functional as functional
|
||||
sys.modules["torchvision.transforms.functional_tensor"] = functional
|
||||
except ImportError:
|
||||
pass # shrug...
|
||||
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
|
||||
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
|
||||
# return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||||
return Image.fromarray(img)
|
||||
|
||||
|
||||
def convert_from_image_to_cv2(img: Image) -> np.ndarray:
|
||||
# return cv2.cvtColor(numpy.array(img), cv2.COLOR_RGB2BGR)
|
||||
return np.asarray(img)
|
||||
|
||||
|
||||
def convert_from_bytes_to_img(raw: bytes) -> Image:
|
||||
return Image.open(io.BytesIO(raw))
|
||||
|
||||
|
||||
def convert_from_img_to_bytes(image: Image, fmt='PNG') -> bytes:
|
||||
byte_arr = io.BytesIO()
|
||||
image.save(byte_arr, format=fmt)
|
||||
return byte_arr.getvalue()
|
||||
|
||||
|
||||
def crop_image(image: Image, max_w: int, max_h: int) -> Image:
|
||||
w, h = image.size
|
||||
if w > max_w or h > max_h:
|
||||
image.thumbnail((max_w, max_h))
|
||||
|
||||
return image.convert('RGB')
|
||||
|
||||
def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image:
|
||||
return crop_image(convert_from_bytes_to_img(raw), max_w, max_h)
|
||||
|
||||
|
||||
class DummyPB:
|
||||
def update(self):
|
||||
...
|
||||
|
||||
@torch.compiler.disable
|
||||
@contextmanager
|
||||
def dummy_progress_bar(*args, **kwargs):
|
||||
yield DummyPB()
|
||||
|
||||
|
||||
def monkey_patch_pipeline_disable_progress_bar(pipe):
|
||||
pipe.progress_bar = dummy_progress_bar
|
||||
|
||||
|
||||
def pipeline_for(
|
||||
model: str,
|
||||
mode: str,
|
||||
mem_fraction: float = 1.0,
|
||||
cache_dir: str | None = None
|
||||
) -> DiffusionPipeline:
|
||||
diffusers.utils.logging.disable_progress_bar()
|
||||
|
||||
logging.info(f'pipeline_for {model} {mode}')
|
||||
# assert torch.cuda.is_available()
|
||||
torch.cuda.empty_cache()
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# full determinism
|
||||
# https://huggingface.co/docs/diffusers/using-diffusers/reproducibility#deterministic-algorithms
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
||||
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
model_info = MODELS[model]
|
||||
shortname = model_info.short
|
||||
|
||||
# disable for compat with "diffuse" method
|
||||
# assert mode in model_info.tags
|
||||
|
||||
# default to checking if custom pipeline exist and return that if not, attempt generic
|
||||
try:
|
||||
normalized_shortname = shortname.replace('-', '_')
|
||||
custom_pipeline = importlib.import_module(f'skynet.dgpu.pipes.{normalized_shortname}')
|
||||
assert custom_pipeline.__model['name'] == model
|
||||
pipe = custom_pipeline.pipeline_for(model, mode, mem_fraction=mem_fraction, cache_dir=cache_dir)
|
||||
monkey_patch_pipeline_disable_progress_bar(pipe)
|
||||
return pipe
|
||||
|
||||
except ImportError:
|
||||
logging.info(f'didn\'t find a custom pipeline file for {shortname}')
|
||||
|
||||
|
||||
req_mem = model_info.mem
|
||||
|
||||
mem_gb = torch.cuda.mem_get_info()[1] / (10**9)
|
||||
mem_gb *= mem_fraction
|
||||
over_mem = mem_gb < req_mem
|
||||
if over_mem:
|
||||
logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
|
||||
|
||||
params = {
|
||||
'torch_dtype': torch.float16,
|
||||
'cache_dir': cache_dir,
|
||||
'variant': 'fp16',
|
||||
}
|
||||
|
||||
match shortname:
|
||||
case 'stable':
|
||||
params['revision'] = 'fp16'
|
||||
params['safety_checker'] = None
|
||||
|
||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||
|
||||
pipe_class = DiffusionPipeline
|
||||
match mode:
|
||||
case 'inpaint':
|
||||
pipe_class = AutoPipelineForInpainting
|
||||
|
||||
case 'img2img':
|
||||
pipe_class = AutoPipelineForImage2Image
|
||||
|
||||
case 'txt2img':
|
||||
pipe_class = AutoPipelineForText2Image
|
||||
|
||||
pipe = pipe_class.from_pretrained(
|
||||
model, **params)
|
||||
|
||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||
pipe.scheduler.config)
|
||||
|
||||
# pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
if over_mem:
|
||||
if mode == 'txt2img':
|
||||
pipe.vae.enable_tiling()
|
||||
pipe.vae.enable_slicing()
|
||||
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
else:
|
||||
# if sys.version_info[1] < 11:
|
||||
# # torch.compile only supported on python < 3.11
|
||||
# pipe.unet = torch.compile(
|
||||
# pipe.unet, mode='reduce-overhead', fullgraph=True)
|
||||
|
||||
pipe = pipe.to('cuda')
|
||||
|
||||
monkey_patch_pipeline_disable_progress_bar(pipe)
|
||||
|
||||
return pipe
|
||||
|
||||
|
||||
def txt2img(
|
||||
hf_token: str,
|
||||
model: str = list(MODELS.keys())[-1],
|
||||
prompt: str = 'a red old tractor in a sunny wheat field',
|
||||
output: str = 'output.png',
|
||||
width: int = 512, height: int = 512,
|
||||
guidance: float = 10,
|
||||
steps: int = 28,
|
||||
seed: Optional[int] = None
|
||||
):
|
||||
login(token=hf_token)
|
||||
pipe = pipeline_for(model, 'txt2img')
|
||||
|
||||
seed = seed if seed else random.randint(0, 2 ** 64)
|
||||
prompt = prompt
|
||||
image = pipe(
|
||||
prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
guidance_scale=guidance, num_inference_steps=steps,
|
||||
generator=torch.Generator("cuda").manual_seed(seed)
|
||||
).images[0]
|
||||
|
||||
image.save(output)
|
||||
|
||||
|
||||
def img2img(
|
||||
hf_token: str,
|
||||
model: str = list(MODELS.keys())[-2],
|
||||
prompt: str = 'a red old tractor in a sunny wheat field',
|
||||
img_path: str = 'input.png',
|
||||
output: str = 'output.png',
|
||||
strength: float = 1.0,
|
||||
guidance: float = 10,
|
||||
steps: int = 28,
|
||||
seed: Optional[int] = None
|
||||
):
|
||||
login(token=hf_token)
|
||||
pipe = pipeline_for(model, 'img2img')
|
||||
|
||||
model_info = MODELS[model]
|
||||
|
||||
with open(img_path, 'rb') as img_file:
|
||||
input_img = convert_from_bytes_and_crop(img_file.read(), model_info.size.w, model_info.size.h)
|
||||
|
||||
seed = seed if seed else random.randint(0, 2 ** 64)
|
||||
prompt = prompt
|
||||
image = pipe(
|
||||
prompt,
|
||||
image=input_img,
|
||||
strength=strength,
|
||||
guidance_scale=guidance, num_inference_steps=steps,
|
||||
generator=torch.Generator("cuda").manual_seed(seed)
|
||||
).images[0]
|
||||
|
||||
image.save(output)
|
||||
|
||||
|
||||
def inpaint(
|
||||
hf_token: str,
|
||||
model: str = list(MODELS.keys())[-3],
|
||||
prompt: str = 'a red old tractor in a sunny wheat field',
|
||||
img_path: str = 'input.png',
|
||||
mask_path: str = 'mask.png',
|
||||
output: str = 'output.png',
|
||||
strength: float = 1.0,
|
||||
guidance: float = 10,
|
||||
steps: int = 28,
|
||||
seed: Optional[int] = None
|
||||
):
|
||||
login(token=hf_token)
|
||||
pipe = pipeline_for(model, 'inpaint')
|
||||
|
||||
model_info = MODELS[model]
|
||||
|
||||
with open(img_path, 'rb') as img_file:
|
||||
input_img = convert_from_bytes_and_crop(img_file.read(), model_info.size.w, model_info.size.h)
|
||||
|
||||
with open(mask_path, 'rb') as mask_file:
|
||||
mask_img = convert_from_bytes_and_crop(mask_file.read(), model_info.size.w, model_info.size.h)
|
||||
|
||||
var_params = {}
|
||||
if 'flux' not in model.lower():
|
||||
var_params['strength'] = strength
|
||||
|
||||
seed = seed if seed else random.randint(0, 2 ** 64)
|
||||
prompt = prompt
|
||||
image = pipe(
|
||||
prompt,
|
||||
image=input_img,
|
||||
mask_image=mask_img,
|
||||
guidance_scale=guidance, num_inference_steps=steps,
|
||||
generator=torch.Generator("cuda").manual_seed(seed),
|
||||
**var_params
|
||||
).images[0]
|
||||
|
||||
image.save(output)
|
||||
|
||||
|
||||
def init_upscaler():
|
||||
config = load_skynet_toml().dgpu
|
||||
model_path = hf_hub_download(
|
||||
'leonelhs/realesrgan',
|
||||
'RealESRGAN_x4plus.pth',
|
||||
token=config.hf_token,
|
||||
cache_dir=config.hf_home
|
||||
)
|
||||
return RealESRGANer(
|
||||
scale=4,
|
||||
model_path=model_path,
|
||||
dni_weight=None,
|
||||
model=RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=4
|
||||
),
|
||||
half=True
|
||||
)
|
||||
|
||||
def upscale(
|
||||
img_path: str = 'input.png',
|
||||
output: str = 'output.png'
|
||||
):
|
||||
input_img = Image.open(img_path).convert('RGB')
|
||||
|
||||
upscaler = init_upscaler()
|
||||
|
||||
up_img, _ = upscaler.enhance(
|
||||
convert_from_image_to_cv2(input_img), outscale=4)
|
||||
|
||||
image = convert_from_cv2_to_image(up_img)
|
||||
image.save(output)
|
|
@ -1,5 +1,3 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import random
|
||||
|
||||
from ..constants import *
|
||||
|
@ -39,7 +37,7 @@ def validate_user_config_request(req: str):
|
|||
case 'model' | 'algo':
|
||||
attr = 'model'
|
||||
val = params[2]
|
||||
shorts = [model_info['short'] for model_info in MODELS.values()]
|
||||
shorts = [model_info.short for model_info in MODELS.values()]
|
||||
if val not in shorts:
|
||||
raise ConfigUnknownAlgorithm(f'no model named {val}')
|
||||
|
||||
|
@ -112,20 +110,10 @@ def validate_user_config_request(req: str):
|
|||
|
||||
|
||||
def perform_auto_conf(config: dict) -> dict:
|
||||
model = config['model']
|
||||
prefered_size_w = 512
|
||||
prefered_size_h = 512
|
||||
|
||||
if 'xl' in model:
|
||||
prefered_size_w = 1024
|
||||
prefered_size_h = 1024
|
||||
|
||||
else:
|
||||
prefered_size_w = 512
|
||||
prefered_size_h = 512
|
||||
model = MODELS[config['model']]
|
||||
|
||||
config['step'] = random.randint(20, 35)
|
||||
config['width'] = prefered_size_w
|
||||
config['height'] = prefered_size_h
|
||||
config['width'] = model.size.w
|
||||
config['height'] = model.size.h
|
||||
|
||||
return config
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
from json import JSONDecodeError
|
||||
import random
|
||||
import logging
|
||||
|
@ -8,16 +6,25 @@ import asyncio
|
|||
from decimal import Decimal
|
||||
from hashlib import sha256
|
||||
from datetime import datetime
|
||||
from contextlib import ExitStack, AsyncExitStack
|
||||
from contextlib import (
|
||||
ExitStack,
|
||||
AsyncExitStack,
|
||||
)
|
||||
from contextlib import asynccontextmanager as acm
|
||||
|
||||
from leap.cleos import CLEOS
|
||||
from leap.sugar import Name, asset_from_str, collect_stdout
|
||||
from leap.sugar import (
|
||||
Name,
|
||||
asset_from_str,
|
||||
collect_stdout,
|
||||
)
|
||||
from leap.hyperion import HyperionAPI
|
||||
# from telebot.types import InputMediaPhoto
|
||||
|
||||
import discord
|
||||
import requests
|
||||
import io
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
|
||||
from skynet.db import open_database_connection
|
||||
from skynet.ipfs import get_ipfs_file, AsyncIPFSHTTP
|
||||
|
@ -66,7 +73,7 @@ class SkynetDiscordFrontend:
|
|||
self.bot = DiscordBot(self)
|
||||
self.cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||
self.hyperion = HyperionAPI(hyperion_url)
|
||||
self.ipfs_node = AsyncIPFSHTTP(ipfs_node)
|
||||
self.ipfs_node = AsyncIPFSHTTP(ipfs_url)
|
||||
|
||||
self._exit_stack = ExitStack()
|
||||
self._async_exit_stack = AsyncExitStack()
|
||||
|
@ -153,7 +160,7 @@ class SkynetDiscordFrontend:
|
|||
|
||||
reward = '20.0000 GPU'
|
||||
res = await self.cleos.a_push_action(
|
||||
'telos.gpu',
|
||||
'gpu.scd',
|
||||
'enqueue',
|
||||
{
|
||||
'user': Name(self.account),
|
||||
|
@ -200,7 +207,7 @@ class SkynetDiscordFrontend:
|
|||
try:
|
||||
submits = await self.hyperion.aget_actions(
|
||||
account=self.account,
|
||||
filter='telos.gpu:submit',
|
||||
filter='gpu.scd:submit',
|
||||
sort='desc',
|
||||
after=request_time
|
||||
)
|
||||
|
@ -234,7 +241,7 @@ class SkynetDiscordFrontend:
|
|||
await message.edit(embed=embed)
|
||||
return False
|
||||
|
||||
tx_link = f'[**Your result on Skynet Explorer**](https://explorer.{DEFAULT_DOMAIN}/v2/explore/transaction/{tx_hash})'
|
||||
tx_link = f'[**Your result on Skynet Explorer**](https://{self.explorer_domain}/v2/explore/transaction/{tx_hash})'
|
||||
|
||||
msg_text += f'**request processed!**\n{tx_link}\n[{timestamp_pretty()}] *trying to download image...*\n '
|
||||
embed = discord.Embed(
|
||||
|
@ -264,7 +271,8 @@ class SkynetDiscordFrontend:
|
|||
results[link] = png_img
|
||||
|
||||
except UnidentifiedImageError:
|
||||
logging.warning(f'couldn\'t get ipfs binary data at {link}!')
|
||||
logging.warning(
|
||||
f'couldn\'t get ipfs binary data at {link}!')
|
||||
|
||||
tasks = [
|
||||
get_and_set_results(ipfs_link),
|
||||
|
@ -280,32 +288,35 @@ class SkynetDiscordFrontend:
|
|||
png_img = results[ipfs_link]
|
||||
|
||||
if not png_img:
|
||||
await self.update_status_message(
|
||||
status_msg,
|
||||
caption,
|
||||
reply_markup=build_redo_menu(),
|
||||
parse_mode='HTML'
|
||||
)
|
||||
logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!')
|
||||
embed.add_field(
|
||||
name='Error', value=f'couldn\'t get ipfs hosted image [**here**]({ipfs_link})!')
|
||||
await message.edit(embed=embed, view=SkynetView(self))
|
||||
return True
|
||||
|
||||
# reword this function, may not need caption
|
||||
caption, embed = generate_reply_caption(
|
||||
user, params, tx_hash, worker, reward)
|
||||
user, params, tx_hash, worker, reward, self.explorer_domain)
|
||||
|
||||
if not resp or resp.status_code != 200:
|
||||
logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!')
|
||||
embed.add_field(name='Error', value=f'couldn\'t get ipfs hosted image [**here**]({ipfs_link})!')
|
||||
await message.edit(embed=embed, view=SkynetView(self))
|
||||
else:
|
||||
logging.info(f'success! sending generated image')
|
||||
await message.delete()
|
||||
if file_id: # img2img
|
||||
embed.set_thumbnail(
|
||||
url='https://ipfs.skygpu.net/ipfs/' + binary_data + '/image.png')
|
||||
embed.set_image(url=ipfs_link)
|
||||
await send(embed=embed, view=SkynetView(self))
|
||||
else: # txt2img
|
||||
embed.set_image(url=ipfs_link)
|
||||
logging.info(f'success! sending generated image')
|
||||
await message.delete()
|
||||
if file_id: # img2img
|
||||
embed.set_image(url=ipfs_link)
|
||||
orig_url = f'https://{self.ipfs_domain}/ipfs/' + binary_data
|
||||
res = requests.get(orig_url, stream=True)
|
||||
if res.status_code == 200:
|
||||
with io.BytesIO(res.content) as img:
|
||||
file = discord.File(img, filename='image.png')
|
||||
embed.set_thumbnail(url='attachment://image.png')
|
||||
await send(embed=embed, view=SkynetView(self), file=file)
|
||||
# orig_url = f'https://{self.ipfs_domain}/ipfs/' \
|
||||
# + binary_data + '/image.png'
|
||||
# embed.set_thumbnail(
|
||||
# url=orig_url)
|
||||
else:
|
||||
await send(embed=embed, view=SkynetView(self))
|
||||
else: # txt2img
|
||||
embed.set_image(url=ipfs_link)
|
||||
await send(embed=embed, view=SkynetView(self))
|
||||
|
||||
return True
|
||||
|
|
|
@ -44,7 +44,7 @@ class DiscordBot(commands.Bot):
|
|||
await channel.send('Skynet bot online', view=SkynetView(self.bot))
|
||||
# intro_msg = await channel.send('Welcome to the Skynet discord bot.\nSkynet is a decentralized compute layer, focused on supporting AI paradigms. Skynet leverages blockchain technology to manage work requests and fills. We are currently featuring image generation and support 11 different models. Get started with the /help command, or just click on some buttons. Here is an example command to generate an image:\n/txt2img a big red tractor in a giant field of corn')
|
||||
intro_msg = await channel.send("Welcome to Skynet's Discord Bot,\n\nSkynet operates as a decentralized compute layer, offering a wide array of support for diverse AI paradigms through the use of blockchain technology. Our present focus is image generation, powered by 11 distinct models.\n\nTo begin exploring, use the '/help' command or directly interact with the provided buttons. Here is an example command to generate an image:\n\n'/txt2img a big red tractor in a giant field of corn'")
|
||||
await intro_msg.pin()
|
||||
# await intro_msg.pin()
|
||||
|
||||
print("\n==============")
|
||||
print("Logged in as")
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
|
@ -42,6 +40,7 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
|||
await ctx.reply(content=reply_txt, view=SkynetView(frontend))
|
||||
|
||||
bot.remove_command('help')
|
||||
|
||||
@bot.command(name='help', help='Responds with a help')
|
||||
async def help(ctx):
|
||||
splt_msg = ctx.message.content.split(' ')
|
||||
|
@ -62,7 +61,7 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
|||
clean_cool_word = '\n'.join(CLEAN_COOL_WORDS)
|
||||
await ctx.send(content=f'```{clean_cool_word}```', view=SkynetView(frontend))
|
||||
|
||||
@bot.command(name='stats', help='See user statistics' )
|
||||
@bot.command(name='stats', help='See user statistics')
|
||||
async def user_stats(ctx):
|
||||
user = ctx.author
|
||||
|
||||
|
@ -96,9 +95,8 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
|||
prompt = ' '.join(ctx.message.content.split(' ')[1:])
|
||||
|
||||
if len(prompt) == 0:
|
||||
await status_msg.edit(content=
|
||||
'Empty text prompt ignored.'
|
||||
)
|
||||
await status_msg.edit(content='Empty text prompt ignored.'
|
||||
)
|
||||
await db_call('update_user_request', status_msg.id, 'Empty text prompt ignored.')
|
||||
return
|
||||
|
||||
|
@ -209,14 +207,23 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
|||
file_id = str(file.id)
|
||||
# file bytes
|
||||
image_raw = await file.read()
|
||||
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
with Image.open(io.BytesIO(image_raw)) as image:
|
||||
w, h = image.size
|
||||
|
||||
if w > 512 or h > 512:
|
||||
if w > user_config['width'] or h > user_config['height']:
|
||||
logging.warning(f'user sent img of size {image.size}')
|
||||
image.thumbnail((512, 512))
|
||||
image.thumbnail(
|
||||
(user_config['width'], user_config['height']))
|
||||
logging.warning(f'resized it to {image.size}')
|
||||
|
||||
# if w > 512 or h > 512:
|
||||
# logging.warning(f'user sent img of size {image.size}')
|
||||
# image.thumbnail((512, 512))
|
||||
# logging.warning(f'resized it to {image.size}')
|
||||
# image.save(f'ipfs-docker-staging/image.png', format='PNG')
|
||||
image_loc = 'ipfs-staging/image.png'
|
||||
image.save(image_loc, format='PNG')
|
||||
|
||||
|
@ -228,9 +235,6 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
|||
|
||||
logging.info(f'mid: {ctx.message.id}')
|
||||
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
|
@ -240,8 +244,8 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
|||
'update_user_stats',
|
||||
user.id,
|
||||
'img2img',
|
||||
last_file=file_id,
|
||||
last_prompt=prompt,
|
||||
last_file=file_id,
|
||||
last_binary=ipfs_hash
|
||||
)
|
||||
|
||||
|
@ -254,8 +258,6 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
|||
if success:
|
||||
await db_call('increment_generated', user.id)
|
||||
|
||||
|
||||
|
||||
# TODO: DELETE BELOW
|
||||
# user = 'testworker3'
|
||||
# status_msg = 'status'
|
||||
|
@ -305,7 +307,7 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
|||
# async def queue(message):
|
||||
# an_hour_ago = datetime.now() - timedelta(hours=1)
|
||||
# queue = await cleos.aget_table(
|
||||
# 'telos.gpu', 'telos.gpu', 'queue',
|
||||
# 'gpu.scd', 'gpu.scd', 'queue',
|
||||
# index_position=2,
|
||||
# key_type='i64',
|
||||
# sort='desc',
|
||||
|
@ -314,7 +316,6 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
|||
# await bot.reply_to(
|
||||
# message, f'Total requests on skynet queue: {len(queue)}')
|
||||
|
||||
|
||||
# @bot.message_handler(commands=['config'])
|
||||
# async def set_config(message):
|
||||
# user = message.from_user.id
|
||||
|
@ -361,7 +362,6 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
|||
#
|
||||
# await bot.send_message(GROUP_ID, message.text[4:])
|
||||
|
||||
|
||||
# generic txt2img handler
|
||||
|
||||
# async def _generic_txt2img(message_or_query):
|
||||
|
@ -562,7 +562,6 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
|||
# binary_data=binary
|
||||
# )
|
||||
|
||||
|
||||
# "proxy" handlers just request routers
|
||||
|
||||
# @bot.message_handler(commands=['txt2img'])
|
||||
|
@ -594,7 +593,6 @@ def create_handler_context(frontend: 'SkynetDiscordFrontend'):
|
|||
# case 'redo':
|
||||
# await _redo(call)
|
||||
|
||||
|
||||
# catch all handler for things we dont support
|
||||
|
||||
# @bot.message_handler(func=lambda message: True)
|
||||
|
|
|
@ -11,14 +11,22 @@ class SkynetView(discord.ui.View):
|
|||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
super().__init__(timeout=None)
|
||||
self.add_item(RedoButton('redo', discord.ButtonStyle.primary, self.bot))
|
||||
self.add_item(Txt2ImgButton('txt2img', discord.ButtonStyle.primary, self.bot))
|
||||
self.add_item(Img2ImgButton('img2img', discord.ButtonStyle.primary, self.bot))
|
||||
self.add_item(StatsButton('stats', discord.ButtonStyle.secondary, self.bot))
|
||||
self.add_item(DonateButton('donate', discord.ButtonStyle.secondary, self.bot))
|
||||
self.add_item(ConfigButton('config', discord.ButtonStyle.secondary, self.bot))
|
||||
self.add_item(HelpButton('help', discord.ButtonStyle.secondary, self.bot))
|
||||
self.add_item(CoolButton('cool', discord.ButtonStyle.secondary, self.bot))
|
||||
self.add_item(RedoButton(
|
||||
'redo', discord.ButtonStyle.primary, self.bot))
|
||||
self.add_item(Txt2ImgButton(
|
||||
'txt2img', discord.ButtonStyle.primary, self.bot))
|
||||
self.add_item(Img2ImgButton(
|
||||
'img2img', discord.ButtonStyle.primary, self.bot))
|
||||
self.add_item(StatsButton(
|
||||
'stats', discord.ButtonStyle.secondary, self.bot))
|
||||
self.add_item(DonateButton(
|
||||
'donate', discord.ButtonStyle.secondary, self.bot))
|
||||
self.add_item(ConfigButton(
|
||||
'config', discord.ButtonStyle.secondary, self.bot))
|
||||
self.add_item(HelpButton(
|
||||
'help', discord.ButtonStyle.secondary, self.bot))
|
||||
self.add_item(CoolButton(
|
||||
'cool', discord.ButtonStyle.secondary, self.bot))
|
||||
|
||||
|
||||
class Txt2ImgButton(discord.ui.Button):
|
||||
|
@ -44,9 +52,8 @@ class Txt2ImgButton(discord.ui.Button):
|
|||
prompt = msg.content
|
||||
|
||||
if len(prompt) == 0:
|
||||
await status_msg.edit(content=
|
||||
'Empty text prompt ignored.'
|
||||
)
|
||||
await status_msg.edit(content='Empty text prompt ignored.'
|
||||
)
|
||||
await db_call('update_user_request', status_msg.id, 'Empty text prompt ignored.')
|
||||
return
|
||||
|
||||
|
@ -111,26 +118,35 @@ class Img2ImgButton(discord.ui.Button):
|
|||
file_id = str(file.id)
|
||||
# file bytes
|
||||
image_raw = await file.read()
|
||||
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
|
||||
with Image.open(io.BytesIO(image_raw)) as image:
|
||||
w, h = image.size
|
||||
|
||||
if w > 512 or h > 512:
|
||||
if w > user_config['width'] or h > user_config['height']:
|
||||
logging.warning(f'user sent img of size {image.size}')
|
||||
image.thumbnail((512, 512))
|
||||
image.thumbnail(
|
||||
(user_config['width'], user_config['height']))
|
||||
logging.warning(f'resized it to {image.size}')
|
||||
|
||||
image.save(f'ipfs-docker-staging/image.png', format='PNG')
|
||||
# if w > 512 or h > 512:
|
||||
# logging.warning(f'user sent img of size {image.size}')
|
||||
# image.thumbnail((512, 512))
|
||||
# logging.warning(f'resized it to {image.size}')
|
||||
# image.save(f'ipfs-docker-staging/image.png', format='PNG')
|
||||
image_loc = 'ipfs-staging/image.png'
|
||||
image.save(image_loc, format='PNG')
|
||||
|
||||
ipfs_hash = ipfs_node.add('image.png')
|
||||
ipfs_node.pin(ipfs_hash)
|
||||
ipfs_info = await ipfs_node.add(image_loc)
|
||||
ipfs_hash = ipfs_info['Hash']
|
||||
await ipfs_node.pin(ipfs_hash)
|
||||
|
||||
logging.info(f'published input image {ipfs_hash} on ipfs')
|
||||
|
||||
logging.info(f'mid: {msg.id}')
|
||||
|
||||
user_config = {**user_row}
|
||||
del user_config['id']
|
||||
|
||||
params = {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
|
@ -140,8 +156,8 @@ class Img2ImgButton(discord.ui.Button):
|
|||
'update_user_stats',
|
||||
user.id,
|
||||
'img2img',
|
||||
last_file=file_id,
|
||||
last_prompt=prompt,
|
||||
last_file=file_id,
|
||||
last_binary=ipfs_hash
|
||||
)
|
||||
|
||||
|
@ -307,5 +323,3 @@ async def grab(prompt, interaction):
|
|||
await interaction.response.send_message(prompt, ephemeral=True)
|
||||
message = await interaction.client.wait_for('message', check=vet)
|
||||
return message
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
|
@ -32,7 +30,8 @@ class SKYExceptionHandler(ExceptionHandler):
|
|||
|
||||
|
||||
def build_redo_menu():
|
||||
btn_redo = InlineKeyboardButton("Redo", callback_data=json.dumps({'method': 'redo'}))
|
||||
btn_redo = InlineKeyboardButton(
|
||||
"Redo", callback_data=json.dumps({'method': 'redo'}))
|
||||
inline_keyboard = InlineKeyboardMarkup()
|
||||
inline_keyboard.add(btn_redo)
|
||||
return inline_keyboard
|
||||
|
@ -42,7 +41,7 @@ def prepare_metainfo_caption(user, worker: str, reward: str, meta: dict, embed)
|
|||
prompt = meta["prompt"]
|
||||
if len(prompt) > 256:
|
||||
prompt = prompt[:256]
|
||||
|
||||
|
||||
gen_str = f'generated by {user.name}\n'
|
||||
gen_str += f'performed by {worker}\n'
|
||||
gen_str += f'reward: {reward}\n'
|
||||
|
@ -69,7 +68,7 @@ def prepare_metainfo_caption(user, worker: str, reward: str, meta: dict, embed)
|
|||
embed.add_field(name='Parameters', value=f'```{meta_str}```', inline=False)
|
||||
|
||||
foot_str = f'Made with Skynet v{VERSION}\n'
|
||||
foot_str += f'JOIN THE SWARM: https://discord.gg/JYM4YPMgK'
|
||||
foot_str += f'JOIN THE SWARM: https://discord.gg/PAabjJtZAF'
|
||||
|
||||
embed.set_footer(text=foot_str)
|
||||
|
||||
|
@ -89,7 +88,8 @@ def generate_reply_caption(
|
|||
url=f'https://{explorer_domain}/v2/explore/transaction/{tx_hash}',
|
||||
color=discord.Color.blue())
|
||||
|
||||
meta_info = prepare_metainfo_caption(user, worker, reward, params, explorer_link)
|
||||
meta_info = prepare_metainfo_caption(
|
||||
user, worker, reward, params, explorer_link)
|
||||
|
||||
# why do we have this?
|
||||
final_msg = '\n'.join([
|
||||
|
@ -98,10 +98,10 @@ def generate_reply_caption(
|
|||
f'PARAMETER INFO:\n{meta_info}'
|
||||
])
|
||||
|
||||
final_msg = '\n'.join([
|
||||
# f'***{explorer_link}***',
|
||||
f'{meta_info}'
|
||||
])
|
||||
# final_msg += '\n'.join([
|
||||
# # f'***{explorer_link}***',
|
||||
# f'{meta_info}'
|
||||
# ])
|
||||
|
||||
logging.info(final_msg)
|
||||
|
||||
|
@ -110,11 +110,12 @@ def generate_reply_caption(
|
|||
|
||||
async def get_global_config(cleos):
|
||||
return (await cleos.aget_table(
|
||||
'telos.gpu', 'telos.gpu', 'config'))[0]
|
||||
'gpu.scd', 'gpu.scd', 'config'))[0]
|
||||
|
||||
|
||||
async def get_user_nonce(cleos, user: str):
|
||||
return (await cleos.aget_table(
|
||||
'telos.gpu', 'telos.gpu', 'users',
|
||||
'gpu.scd', 'gpu.scd', 'users',
|
||||
index_position=1,
|
||||
key_type='name',
|
||||
lower_bound=user,
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import io
|
||||
import random
|
||||
import logging
|
||||
|
@ -14,7 +12,7 @@ from contextlib import AsyncExitStack
|
|||
from contextlib import asynccontextmanager as acm
|
||||
|
||||
from leap.cleos import CLEOS
|
||||
from leap.sugar import Name, asset_from_str, collect_stdout
|
||||
from leap.protocol import Name, Asset
|
||||
from leap.hyperion import HyperionAPI
|
||||
|
||||
from telebot.types import InputMediaPhoto
|
||||
|
@ -43,7 +41,6 @@ class SkynetTelegramFrontend:
|
|||
db_user: str,
|
||||
db_pass: str,
|
||||
ipfs_node: str,
|
||||
remote_ipfs_node: str | None,
|
||||
key: str,
|
||||
explorer_domain: str,
|
||||
ipfs_domain: str
|
||||
|
@ -56,22 +53,19 @@ class SkynetTelegramFrontend:
|
|||
self.db_host = db_host
|
||||
self.db_user = db_user
|
||||
self.db_pass = db_pass
|
||||
self.remote_ipfs_node = remote_ipfs_node
|
||||
self.key = key
|
||||
self.explorer_domain = explorer_domain
|
||||
self.ipfs_domain = ipfs_domain
|
||||
|
||||
self.bot = AsyncTeleBot(token, exception_handler=SKYExceptionHandler)
|
||||
self.cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||
self.cleos = CLEOS(endpoint=node_url)
|
||||
self.cleos.load_abi('gpu.scd', GPU_CONTRACT_ABI)
|
||||
self.hyperion = HyperionAPI(hyperion_url)
|
||||
self.ipfs_node = AsyncIPFSHTTP(ipfs_node)
|
||||
|
||||
self._async_exit_stack = AsyncExitStack()
|
||||
|
||||
async def start(self):
|
||||
if self.remote_ipfs_node:
|
||||
await self.ipfs_node.connect(self.remote_ipfs_node)
|
||||
|
||||
self.db_call = await self._async_exit_stack.enter_async_context(
|
||||
open_database_connection(
|
||||
self.db_user, self.db_pass, self.db_host))
|
||||
|
@ -116,7 +110,7 @@ class SkynetTelegramFrontend:
|
|||
method: str,
|
||||
params: dict,
|
||||
file_id: str | None = None,
|
||||
binary_data: str = ''
|
||||
inputs: list[str] = []
|
||||
) -> bool:
|
||||
if params['seed'] == None:
|
||||
params['seed'] = random.randint(0, 0xFFFFFFFF)
|
||||
|
@ -143,15 +137,15 @@ class SkynetTelegramFrontend:
|
|||
|
||||
reward = '20.0000 GPU'
|
||||
res = await self.cleos.a_push_action(
|
||||
'telos.gpu',
|
||||
'gpu.scd',
|
||||
'enqueue',
|
||||
{
|
||||
list({
|
||||
'user': Name(self.account),
|
||||
'request_body': body,
|
||||
'binary_data': binary_data,
|
||||
'reward': asset_from_str(reward),
|
||||
'binary_data': ','.join(inputs),
|
||||
'reward': Asset.from_str(reward),
|
||||
'min_verification': 1
|
||||
},
|
||||
}.values()),
|
||||
self.account, self.key, permission=self.permission
|
||||
)
|
||||
|
||||
|
@ -176,12 +170,12 @@ class SkynetTelegramFrontend:
|
|||
parse_mode='HTML'
|
||||
)
|
||||
|
||||
out = collect_stdout(res)
|
||||
out = res['processed']['action_traces'][0]['console']
|
||||
|
||||
request_id, nonce = out.split(':')
|
||||
|
||||
request_hash = sha256(
|
||||
(nonce + body + binary_data).encode('utf-8')).hexdigest().upper()
|
||||
(nonce + body + ','.join(inputs)).encode('utf-8')).hexdigest().upper()
|
||||
|
||||
request_id = int(request_id)
|
||||
|
||||
|
@ -189,11 +183,11 @@ class SkynetTelegramFrontend:
|
|||
|
||||
tx_hash = None
|
||||
ipfs_hash = None
|
||||
for i in range(60):
|
||||
for i in range(60 * 3):
|
||||
try:
|
||||
submits = await self.hyperion.aget_actions(
|
||||
account=self.account,
|
||||
filter='telos.gpu:submit',
|
||||
filter='gpu.scd:submit',
|
||||
sort='desc',
|
||||
after=request_time
|
||||
)
|
||||
|
@ -241,46 +235,28 @@ class SkynetTelegramFrontend:
|
|||
user, params, tx_hash, worker, reward, self.explorer_domain)
|
||||
|
||||
# attempt to get the image and send it
|
||||
results = {}
|
||||
ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}'
|
||||
ipfs_link_legacy = ipfs_link + '/image.png'
|
||||
|
||||
async def get_and_set_results(link: str):
|
||||
res = await get_ipfs_file(link)
|
||||
logging.info(f'got response from {link}')
|
||||
if not res or res.status_code != 200:
|
||||
logging.warning(f'couldn\'t get ipfs binary data at {link}!')
|
||||
res = await get_ipfs_file(ipfs_link)
|
||||
logging.info(f'got response from {ipfs_link}')
|
||||
if not res or res.status_code != 200:
|
||||
logging.warning(f'couldn\'t get ipfs binary data at {ipfs_link}!')
|
||||
|
||||
else:
|
||||
try:
|
||||
with Image.open(io.BytesIO(res.raw)) as image:
|
||||
w, h = image.size
|
||||
else:
|
||||
try:
|
||||
with Image.open(io.BytesIO(res.raw)) as image:
|
||||
w, h = image.size
|
||||
|
||||
if w > TG_MAX_WIDTH or h > TG_MAX_HEIGHT:
|
||||
logging.warning(f'result is of size {image.size}')
|
||||
image.thumbnail((TG_MAX_WIDTH, TG_MAX_HEIGHT))
|
||||
if w > TG_MAX_WIDTH or h > TG_MAX_HEIGHT:
|
||||
logging.warning(f'result is of size {image.size}')
|
||||
image.thumbnail((TG_MAX_WIDTH, TG_MAX_HEIGHT))
|
||||
|
||||
tmp_buf = io.BytesIO()
|
||||
image.save(tmp_buf, format='PNG')
|
||||
png_img = tmp_buf.getvalue()
|
||||
tmp_buf = io.BytesIO()
|
||||
image.save(tmp_buf, format='PNG')
|
||||
png_img = tmp_buf.getvalue()
|
||||
|
||||
results[link] = png_img
|
||||
|
||||
except UnidentifiedImageError:
|
||||
logging.warning(f'couldn\'t get ipfs binary data at {link}!')
|
||||
|
||||
tasks = [
|
||||
get_and_set_results(ipfs_link),
|
||||
get_and_set_results(ipfs_link_legacy)
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
png_img = None
|
||||
if ipfs_link_legacy in results:
|
||||
png_img = results[ipfs_link_legacy]
|
||||
|
||||
if ipfs_link in results:
|
||||
png_img = results[ipfs_link]
|
||||
except UnidentifiedImageError:
|
||||
logging.warning(f'couldn\'t get ipfs binary data at {ipfs_link}!')
|
||||
|
||||
if not png_img:
|
||||
await self.update_status_message(
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
|
@ -47,7 +45,7 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
|||
async def queue(message):
|
||||
an_hour_ago = datetime.now() - timedelta(hours=1)
|
||||
queue = await cleos.aget_table(
|
||||
'telos.gpu', 'telos.gpu', 'queue',
|
||||
'gpu.scd', 'gpu.scd', 'queue',
|
||||
index_position=2,
|
||||
key_type='i64',
|
||||
sort='desc',
|
||||
|
@ -254,7 +252,7 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
|||
success = await work_request(
|
||||
user, status_msg, 'img2img', params,
|
||||
file_id=file_id,
|
||||
binary_data=ipfs_hash
|
||||
inputs=ipfs_hash
|
||||
)
|
||||
|
||||
if success:
|
||||
|
@ -320,7 +318,7 @@ def create_handler_context(frontend: 'SkynetTelegramFrontend'):
|
|||
success = await work_request(
|
||||
user, status_msg, 'redo', params,
|
||||
file_id=file_id,
|
||||
binary_data=binary
|
||||
inputs=binary
|
||||
)
|
||||
|
||||
if success:
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
|
@ -72,7 +70,7 @@ def generate_reply_caption(
|
|||
):
|
||||
explorer_link = hlink(
|
||||
'SKYNET Transaction Explorer',
|
||||
f'https://explorer.{explorer_domain}/v2/explore/transaction/{tx_hash}'
|
||||
f'https://{explorer_domain}/v2/explore/transaction/{tx_hash}'
|
||||
)
|
||||
|
||||
meta_info = prepare_metainfo_caption(tguser, worker, reward, params)
|
||||
|
@ -95,11 +93,11 @@ def generate_reply_caption(
|
|||
|
||||
async def get_global_config(cleos):
|
||||
return (await cleos.aget_table(
|
||||
'telos.gpu', 'telos.gpu', 'config'))[0]
|
||||
'gpu.scd', 'gpu.scd', 'config'))[0]
|
||||
|
||||
async def get_user_nonce(cleos, user: str):
|
||||
return (await cleos.aget_table(
|
||||
'telos.gpu', 'telos.gpu', 'users',
|
||||
'gpu.scd', 'gpu.scd', 'users',
|
||||
index_position=1,
|
||||
key_type='name',
|
||||
lower_bound=user,
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import asks
|
||||
import httpx
|
||||
|
||||
|
||||
class IPFSClientException(BaseException):
|
||||
class IPFSClientException(Exception):
|
||||
...
|
||||
|
||||
|
||||
|
@ -16,10 +14,11 @@ class AsyncIPFSHTTP:
|
|||
self.endpoint = endpoint
|
||||
|
||||
async def _post(self, sub_url: str, *args, **kwargs):
|
||||
resp = await asks.post(
|
||||
self.endpoint + sub_url,
|
||||
*args, **kwargs
|
||||
)
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
self.endpoint + sub_url,
|
||||
*args, **kwargs
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise IPFSClientException(resp.text)
|
||||
|
@ -28,7 +27,7 @@ class AsyncIPFSHTTP:
|
|||
|
||||
async def add(self, file_path: Path, **kwargs):
|
||||
files = {
|
||||
'file': file_path
|
||||
'file': (file_path.name, file_path.open('rb'))
|
||||
}
|
||||
return await self._post(
|
||||
'/api/v0/add',
|
||||
|
@ -55,18 +54,19 @@ class AsyncIPFSHTTP:
|
|||
))['Peers']
|
||||
|
||||
|
||||
async def get_ipfs_file(ipfs_link: str, timeout: int = 60):
|
||||
async def get_ipfs_file(ipfs_link: str, timeout: int = 60 * 5):
|
||||
logging.info(f'attempting to get image at {ipfs_link}')
|
||||
resp = None
|
||||
for i in range(timeout):
|
||||
for _ in range(timeout):
|
||||
try:
|
||||
resp = await asks.get(ipfs_link, timeout=3)
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(ipfs_link, timeout=3)
|
||||
|
||||
except asks.errors.RequestTimeout:
|
||||
logging.warning('timeout...')
|
||||
except httpx.RequestError as e:
|
||||
logging.warning(f'Request error: {e}')
|
||||
|
||||
except asks.errors.BadHttpResponse as e:
|
||||
logging.error(f'ifps gateway exception: \n{e}')
|
||||
if resp is not None:
|
||||
break
|
||||
|
||||
if resp:
|
||||
logging.info(f'status_code: {resp.status_code}')
|
||||
|
|
|
@ -1,69 +0,0 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import sys
|
||||
import logging
|
||||
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager as cm
|
||||
|
||||
import docker
|
||||
|
||||
from docker.types import Mount
|
||||
|
||||
|
||||
@cm
|
||||
def open_ipfs_node(
|
||||
name: str = 'skynet-ipfs',
|
||||
teardown: bool = False,
|
||||
peers: list[str] = []
|
||||
):
|
||||
dclient = docker.from_env()
|
||||
|
||||
container = None
|
||||
try:
|
||||
container = dclient.containers.get(name)
|
||||
|
||||
except docker.errors.NotFound:
|
||||
data_dir = Path().resolve() / 'ipfs-docker-data'
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
data_target = '/data/ipfs'
|
||||
|
||||
container = dclient.containers.run(
|
||||
'ipfs/go-ipfs:latest',
|
||||
name='skynet-ipfs',
|
||||
ports={
|
||||
'8080/tcp': 8080,
|
||||
'4001/tcp': 4001,
|
||||
'5001/tcp': ('127.0.0.1', 5001)
|
||||
},
|
||||
mounts=[
|
||||
Mount(data_target, str(data_dir), 'bind')
|
||||
],
|
||||
detach=True,
|
||||
remove=True
|
||||
)
|
||||
|
||||
uid, gid = 1000, 1000
|
||||
|
||||
if sys.platform != 'win32':
|
||||
ec, out = container.exec_run(['chown', f'{uid}:{gid}', '-R', data_target])
|
||||
logging.info(out)
|
||||
assert ec == 0
|
||||
|
||||
for log in container.logs(stream=True):
|
||||
log = log.decode().rstrip()
|
||||
logging.info(log)
|
||||
if 'Daemon is ready' in log:
|
||||
break
|
||||
|
||||
for peer in peers:
|
||||
ec, out = container.exec_run(
|
||||
['ipfs', 'swarm', 'connect', peer])
|
||||
if ec != 0:
|
||||
logging.error(out)
|
||||
|
||||
yield
|
||||
|
||||
if teardown and container:
|
||||
container.stop()
|
|
@ -1,8 +1,4 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import trio
|
||||
|
@ -44,8 +40,8 @@ class SkynetPinner:
|
|||
|
||||
async def capture_enqueues(self, after: datetime):
|
||||
enqueues = await self.hyperion.aget_actions(
|
||||
account='telos.gpu',
|
||||
filter='telos.gpu:enqueue',
|
||||
account='gpu.scd',
|
||||
filter='gpu.scd:enqueue',
|
||||
sort='desc',
|
||||
after=after.isoformat(),
|
||||
limit=1000
|
||||
|
@ -55,16 +51,16 @@ class SkynetPinner:
|
|||
|
||||
cids = []
|
||||
for action in enqueues['actions']:
|
||||
cid = action['act']['data']['binary_data']
|
||||
if cid and not self.is_pinned(cid):
|
||||
cids.append(cid)
|
||||
for cid in action['act']['data']['binary_data'].split(','):
|
||||
if cid and not self.is_pinned(cid):
|
||||
cids.append(cid)
|
||||
|
||||
return cids
|
||||
|
||||
async def capture_submits(self, after: datetime):
|
||||
submits = await self.hyperion.aget_actions(
|
||||
account='telos.gpu',
|
||||
filter='telos.gpu:submit',
|
||||
account='gpu.scd',
|
||||
filter='gpu.scd:submit',
|
||||
sort='desc',
|
||||
after=after.isoformat(),
|
||||
limit=1000
|
||||
|
@ -118,8 +114,8 @@ class SkynetPinner:
|
|||
for cid in cids:
|
||||
n.start_soon(self.task_pin, cid)
|
||||
|
||||
except OSError as e:
|
||||
traceback.print_exc()
|
||||
except OSError:
|
||||
logging.exception('OSError while trying to pin?')
|
||||
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
|
|
145
skynet/nodeos.py
145
skynet/nodeos.py
|
@ -1,145 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
|
||||
from contextlib import contextmanager as cm
|
||||
|
||||
import docker
|
||||
|
||||
from leap.cleos import CLEOS
|
||||
from leap.sugar import get_container, Symbol
|
||||
|
||||
|
||||
@cm
|
||||
def open_nodeos(cleanup: bool = True):
|
||||
dclient = docker.from_env()
|
||||
vtestnet = get_container(
|
||||
dclient,
|
||||
'guilledk/skynet:leap-4.0.1',
|
||||
name='skynet-nodeos',
|
||||
force_unique=True,
|
||||
detach=True,
|
||||
network='host')
|
||||
|
||||
try:
|
||||
cleos = CLEOS(
|
||||
dclient, vtestnet,
|
||||
url='http://127.0.0.1:42000',
|
||||
remote='http://127.0.0.1:42000'
|
||||
)
|
||||
|
||||
cleos.start_keosd()
|
||||
|
||||
priv, pub = cleos.create_key_pair()
|
||||
logging.info(f'SUDO KEYS: {(priv, pub)}')
|
||||
|
||||
cleos.setup_wallet(priv)
|
||||
|
||||
genesis = json.dumps({
|
||||
"initial_timestamp": '2017-08-29T02:14:00.000',
|
||||
"initial_key": pub,
|
||||
"initial_configuration": {
|
||||
"max_block_net_usage": 1048576,
|
||||
"target_block_net_usage_pct": 1000,
|
||||
"max_transaction_net_usage": 1048575,
|
||||
"base_per_transaction_net_usage": 12,
|
||||
"net_usage_leeway": 500,
|
||||
"context_free_discount_net_usage_num": 20,
|
||||
"context_free_discount_net_usage_den": 100,
|
||||
"max_block_cpu_usage": 200000,
|
||||
"target_block_cpu_usage_pct": 1000,
|
||||
"max_transaction_cpu_usage": 150000,
|
||||
"min_transaction_cpu_usage": 100,
|
||||
"max_transaction_lifetime": 3600,
|
||||
"deferred_trx_expiration_window": 600,
|
||||
"max_transaction_delay": 3888000,
|
||||
"max_inline_action_size": 4096,
|
||||
"max_inline_action_depth": 4,
|
||||
"max_authority_depth": 6
|
||||
}
|
||||
}, indent=4)
|
||||
|
||||
ec, out = cleos.run(
|
||||
['bash', '-c', f'echo \'{genesis}\' > /root/skynet.json'])
|
||||
assert ec == 0
|
||||
|
||||
place_holder = 'EOS5fLreY5Zq5owBhmNJTgQaLqQ4ufzXSTpStQakEyfxNFuUEgNs1=KEY:5JnvSc6pewpHHuUHwvbJopsew6AKwiGnexwDRc2Pj2tbdw6iML9'
|
||||
sig_provider = f'{pub}=KEY:{priv}'
|
||||
nodeos_config_ini = '/root/nodeos/config.ini'
|
||||
ec, out = cleos.run(
|
||||
['bash', '-c', f'sed -i -e \'s/{place_holder}/{sig_provider}/g\' {nodeos_config_ini}'])
|
||||
assert ec == 0
|
||||
|
||||
cleos.start_nodeos_from_config(
|
||||
nodeos_config_ini,
|
||||
data_dir='/root/nodeos/data',
|
||||
genesis='/root/skynet.json',
|
||||
state_plugin=True)
|
||||
|
||||
time.sleep(0.5)
|
||||
cleos.wait_blocks(1)
|
||||
cleos.boot_sequence(token_sym=Symbol('GPU', 4))
|
||||
|
||||
priv, pub = cleos.create_key_pair()
|
||||
cleos.import_key(priv)
|
||||
cleos.private_keys['telos.gpu'] = priv
|
||||
logging.info(f'GPU KEYS: {(priv, pub)}')
|
||||
cleos.new_account('telos.gpu', ram=4200000, key=pub)
|
||||
|
||||
for i in range(1, 4):
|
||||
priv, pub = cleos.create_key_pair()
|
||||
cleos.import_key(priv)
|
||||
cleos.private_keys[f'testworker{i}'] = priv
|
||||
logging.info(f'testworker{i} KEYS: {(priv, pub)}')
|
||||
cleos.create_account_staked(
|
||||
'eosio', f'testworker{i}', key=pub)
|
||||
|
||||
priv, pub = cleos.create_key_pair()
|
||||
cleos.import_key(priv)
|
||||
logging.info(f'TELEGRAM KEYS: {(priv, pub)}')
|
||||
cleos.create_account_staked(
|
||||
'eosio', 'telegram', ram=500000, key=pub)
|
||||
|
||||
cleos.transfer_token(
|
||||
'eosio', 'telegram', '1000000.0000 GPU', 'Initial testing funds')
|
||||
|
||||
cleos.deploy_contract_from_host(
|
||||
'telos.gpu',
|
||||
'tests/contracts/telos.gpu',
|
||||
verify_hash=False,
|
||||
create_account=False
|
||||
)
|
||||
|
||||
ec, out = cleos.push_action(
|
||||
'telos.gpu',
|
||||
'config',
|
||||
['eosio.token', '4,GPU'],
|
||||
f'telos.gpu@active'
|
||||
)
|
||||
assert ec == 0
|
||||
|
||||
ec, out = cleos.transfer_token(
|
||||
'telegram', 'telos.gpu', '1000000.0000 GPU', 'Initial testing funds')
|
||||
assert ec == 0
|
||||
|
||||
user_row = cleos.get_table(
|
||||
'telos.gpu',
|
||||
'telos.gpu',
|
||||
'users',
|
||||
index_position=1,
|
||||
key_type='name',
|
||||
lower_bound='telegram',
|
||||
upper_bound='telegram'
|
||||
)
|
||||
assert len(user_row) == 1
|
||||
|
||||
yield cleos
|
||||
|
||||
finally:
|
||||
# ec, out = cleos.list_all_keys()
|
||||
# logging.info(out)
|
||||
if cleanup:
|
||||
vtestnet.stop()
|
||||
vtestnet.remove()
|
238
skynet/utils.py
238
skynet/utils.py
|
@ -1,238 +0,0 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import random
|
||||
import logging
|
||||
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
import asks
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from diffusers import (
|
||||
DiffusionPipeline,
|
||||
EulerAncestralDiscreteScheduler
|
||||
)
|
||||
from realesrgan import RealESRGANer
|
||||
from huggingface_hub import login
|
||||
import trio
|
||||
|
||||
from .constants import MODELS
|
||||
|
||||
|
||||
def time_ms():
|
||||
return int(time.time() * 1000)
|
||||
|
||||
|
||||
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
|
||||
# return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||||
return Image.fromarray(img)
|
||||
|
||||
|
||||
def convert_from_image_to_cv2(img: Image) -> np.ndarray:
|
||||
# return cv2.cvtColor(numpy.array(img), cv2.COLOR_RGB2BGR)
|
||||
return np.asarray(img)
|
||||
|
||||
|
||||
def convert_from_bytes_to_img(raw: bytes) -> Image:
|
||||
return Image.open(io.BytesIO(raw))
|
||||
|
||||
|
||||
def convert_from_img_to_bytes(image: Image, fmt='PNG') -> bytes:
|
||||
byte_arr = io.BytesIO()
|
||||
image.save(byte_arr, format=fmt)
|
||||
return byte_arr.getvalue()
|
||||
|
||||
|
||||
def crop_image(image: Image, max_w: int, max_h: int) -> Image:
|
||||
w, h = image.size
|
||||
if w > max_w or h > max_h:
|
||||
image.thumbnail((max_w, max_h))
|
||||
|
||||
return image.convert('RGB')
|
||||
|
||||
|
||||
def pipeline_for(
|
||||
model: str,
|
||||
mem_fraction: float = 1.0,
|
||||
image: bool = False,
|
||||
cache_dir: str | None = None
|
||||
) -> DiffusionPipeline:
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
torch.cuda.empty_cache()
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# full determinism
|
||||
# https://huggingface.co/docs/diffusers/using-diffusers/reproducibility#deterministic-algorithms
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
||||
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
model_info = MODELS[model]
|
||||
|
||||
req_mem = model_info['mem']
|
||||
mem_gb = torch.cuda.mem_get_info()[1] / (10**9)
|
||||
mem_gb *= mem_fraction
|
||||
over_mem = mem_gb < req_mem
|
||||
if over_mem:
|
||||
logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
|
||||
|
||||
shortname = model_info['short']
|
||||
|
||||
params = {
|
||||
'safety_checker': None,
|
||||
'torch_dtype': torch.float16,
|
||||
'cache_dir': cache_dir,
|
||||
'variant': 'fp16'
|
||||
}
|
||||
|
||||
match shortname:
|
||||
case 'stable':
|
||||
params['revision'] = 'fp16'
|
||||
|
||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model, **params)
|
||||
|
||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||
pipe.scheduler.config)
|
||||
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
if over_mem:
|
||||
if not image:
|
||||
pipe.enable_vae_slicing()
|
||||
pipe.enable_vae_tiling()
|
||||
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
else:
|
||||
if sys.version_info[1] < 11:
|
||||
# torch.compile only supported on python < 3.11
|
||||
pipe.unet = torch.compile(
|
||||
pipe.unet, mode='reduce-overhead', fullgraph=True)
|
||||
|
||||
pipe = pipe.to('cuda')
|
||||
|
||||
return pipe
|
||||
|
||||
|
||||
def txt2img(
|
||||
hf_token: str,
|
||||
model: str = 'prompthero/openjourney',
|
||||
prompt: str = 'a red old tractor in a sunny wheat field',
|
||||
output: str = 'output.png',
|
||||
width: int = 512, height: int = 512,
|
||||
guidance: float = 10,
|
||||
steps: int = 28,
|
||||
seed: Optional[int] = None
|
||||
):
|
||||
login(token=hf_token)
|
||||
pipe = pipeline_for(model)
|
||||
|
||||
seed = seed if seed else random.randint(0, 2 ** 64)
|
||||
prompt = prompt
|
||||
image = pipe(
|
||||
prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
guidance_scale=guidance, num_inference_steps=steps,
|
||||
generator=torch.Generator("cuda").manual_seed(seed)
|
||||
).images[0]
|
||||
|
||||
image.save(output)
|
||||
|
||||
|
||||
def img2img(
|
||||
hf_token: str,
|
||||
model: str = 'prompthero/openjourney',
|
||||
prompt: str = 'a red old tractor in a sunny wheat field',
|
||||
img_path: str = 'input.png',
|
||||
output: str = 'output.png',
|
||||
strength: float = 1.0,
|
||||
guidance: float = 10,
|
||||
steps: int = 28,
|
||||
seed: Optional[int] = None
|
||||
):
|
||||
login(token=hf_token)
|
||||
pipe = pipeline_for(model, image=True)
|
||||
|
||||
with open(img_path, 'rb') as img_file:
|
||||
input_img = convert_from_bytes_and_crop(img_file.read(), 512, 512)
|
||||
|
||||
seed = seed if seed else random.randint(0, 2 ** 64)
|
||||
prompt = prompt
|
||||
image = pipe(
|
||||
prompt,
|
||||
image=input_img,
|
||||
strength=strength,
|
||||
guidance_scale=guidance, num_inference_steps=steps,
|
||||
generator=torch.Generator("cuda").manual_seed(seed)
|
||||
).images[0]
|
||||
|
||||
image.save(output)
|
||||
|
||||
|
||||
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
|
||||
return RealESRGANer(
|
||||
scale=4,
|
||||
model_path=model_path,
|
||||
dni_weight=None,
|
||||
model=RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=4
|
||||
),
|
||||
half=True
|
||||
)
|
||||
|
||||
def upscale(
|
||||
img_path: str = 'input.png',
|
||||
output: str = 'output.png',
|
||||
model_path: str = 'weights/RealESRGAN_x4plus.pth'
|
||||
):
|
||||
input_img = Image.open(img_path).convert('RGB')
|
||||
|
||||
upscaler = init_upscaler(model_path=model_path)
|
||||
|
||||
up_img, _ = upscaler.enhance(
|
||||
convert_from_image_to_cv2(input_img), outscale=4)
|
||||
|
||||
image = convert_from_cv2_to_image(up_img)
|
||||
image.save(output)
|
||||
|
||||
|
||||
async def download_upscaler():
|
||||
print('downloading upscaler...')
|
||||
weights_path = Path('weights')
|
||||
weights_path.mkdir(exist_ok=True)
|
||||
upscaler_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
|
||||
save_path = weights_path / 'RealESRGAN_x4plus.pth'
|
||||
response = await asks.get(upscaler_url)
|
||||
with open(save_path, 'wb') as f:
|
||||
f.write(response.content)
|
||||
print('done')
|
||||
|
||||
def download_all_models(hf_token: str, hf_home: str):
|
||||
assert torch.cuda.is_available()
|
||||
|
||||
trio.run(download_upscaler)
|
||||
|
||||
login(token=hf_token)
|
||||
for model in MODELS:
|
||||
print(f'DOWNLOADING {model.upper()}')
|
||||
pipeline_for(model, cache_dir=hf_home)
|
|
@ -1,24 +1,54 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import pytest
|
||||
|
||||
from skynet.db import open_new_database
|
||||
from skynet.config import *
|
||||
from skynet.ipfs import AsyncIPFSHTTP
|
||||
from skynet.ipfs.docker import open_ipfs_node
|
||||
from skynet.nodeos import open_nodeos
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def ipfs_client():
|
||||
with open_ipfs_node(teardown=True):
|
||||
yield AsyncIPFSHTTP('http://127.0.0.1:5001')
|
||||
yield AsyncIPFSHTTP('http://127.0.0.1:5001')
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def postgres_db():
|
||||
from skynet.db import open_new_database
|
||||
with open_new_database() as db_params:
|
||||
yield db_params
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def cleos():
|
||||
with open_nodeos() as cli:
|
||||
yield cli
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def skynet_cleos(cleos_bs):
|
||||
cleos = cleos_bs
|
||||
|
||||
priv, pub = cleos.create_key_pair()
|
||||
cleos.import_key('telos.gpu', priv)
|
||||
cleos.new_account('telos.gpu', ram=4200000, key=pub)
|
||||
|
||||
cleos.deploy_contract_from_path(
|
||||
'telos.gpu',
|
||||
'tests/contracts/telos.gpu',
|
||||
create_account=False
|
||||
)
|
||||
|
||||
cleos.push_action(
|
||||
'telos.gpu',
|
||||
'config',
|
||||
['eosio.token', '4,GPU'],
|
||||
'telos.gpu'
|
||||
)
|
||||
|
||||
yield cleos
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def inject_mockers():
|
||||
from skynet.constants import MODELS, ModelDesc
|
||||
|
||||
MODELS['skygpu/txt2img-mocker'] = ModelDesc(
|
||||
short='tester',
|
||||
mem=0.01,
|
||||
attrs={},
|
||||
tags=['txt2img']
|
||||
)
|
||||
|
||||
yield
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
def test_dev(skynet_cleos):
|
||||
cleos = skynet_cleos
|
||||
...
|
|
@ -1,106 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import time
|
||||
import json
|
||||
|
||||
from hashlib import sha256
|
||||
from functools import partial
|
||||
|
||||
import trio
|
||||
import requests
|
||||
from skynet.constants import DEFAULT_IPFS_REMOTE
|
||||
|
||||
from skynet.dgpu import open_dgpu_node
|
||||
|
||||
from leap.sugar import collect_stdout
|
||||
|
||||
|
||||
def test_enqueue_work(cleos):
|
||||
user = 'telegram'
|
||||
req = json.dumps({
|
||||
'method': 'diffuse',
|
||||
'params': {
|
||||
'algo': 'midj',
|
||||
'prompt': 'skynet terminator dystopic',
|
||||
'width': 512,
|
||||
'height': 512,
|
||||
'guidance': 10,
|
||||
'step': 28,
|
||||
'seed': 420,
|
||||
'upscaler': 'x4'
|
||||
}
|
||||
})
|
||||
binary = ''
|
||||
|
||||
ec, out = cleos.push_action(
|
||||
'telos.gpu', 'enqueue', [user, req, binary, '20.0000 GPU', 1], f'{user}@active'
|
||||
)
|
||||
|
||||
assert ec == 0
|
||||
|
||||
queue = cleos.get_table('telos.gpu', 'telos.gpu', 'queue')
|
||||
|
||||
assert len(queue) == 1
|
||||
|
||||
req_on_chain = queue[0]
|
||||
|
||||
assert req_on_chain['user'] == user
|
||||
assert req_on_chain['body'] == req
|
||||
assert req_on_chain['binary_data'] == binary
|
||||
|
||||
trio.run(
|
||||
partial(
|
||||
open_dgpu_node,
|
||||
f'testworker1',
|
||||
'active',
|
||||
cleos,
|
||||
DEFAULT_IPFS_REMOTE,
|
||||
cleos.private_keys['testworker1'],
|
||||
initial_algos=['midj']
|
||||
)
|
||||
)
|
||||
|
||||
queue = cleos.get_table('telos.gpu', 'telos.gpu', 'queue')
|
||||
|
||||
assert len(queue) == 0
|
||||
|
||||
|
||||
def test_enqueue_dequeue(cleos):
|
||||
user = 'telegram'
|
||||
req = json.dumps({
|
||||
'method': 'diffuse',
|
||||
'params': {
|
||||
'algo': 'midj',
|
||||
'prompt': 'skynet terminator dystopic',
|
||||
'width': 512,
|
||||
'height': 512,
|
||||
'guidance': 10,
|
||||
'step': 28,
|
||||
'seed': 420,
|
||||
'upscaler': 'x4'
|
||||
}
|
||||
})
|
||||
binary = ''
|
||||
|
||||
ec, out = cleos.push_action(
|
||||
'telos.gpu', 'enqueue', [user, req, binary, '20.0000 GPU', 1], f'{user}@active'
|
||||
)
|
||||
|
||||
assert ec == 0
|
||||
|
||||
request_id, _ = collect_stdout(out).split(':')
|
||||
request_id = int(request_id)
|
||||
|
||||
queue = cleos.get_table('telos.gpu', 'telos.gpu', 'queue')
|
||||
|
||||
assert len(queue) == 1
|
||||
|
||||
ec, out = cleos.push_action(
|
||||
'telos.gpu', 'dequeue', [user, request_id], f'{user}@active'
|
||||
)
|
||||
|
||||
assert ec == 0
|
||||
|
||||
queue = cleos.get_table('telos.gpu', 'telos.gpu', 'queue')
|
||||
|
||||
assert len(queue) == 0
|
|
@ -1,6 +1,3 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
import pytest
|
||||
|
||||
from skynet.dgpu.compute import maybe_load_model, compute_one
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mode", [
|
||||
('diffuse'), ('txt2img')
|
||||
])
|
||||
async def test_pipeline_mocker(inject_mockers, mode):
|
||||
model = 'skygpu/txt2img-mocker'
|
||||
params = {
|
||||
"prompt": "Kronos God Realistic 4k",
|
||||
"model": model,
|
||||
"step": 21,
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"seed": 168402949,
|
||||
"guidance": "7.5"
|
||||
}
|
||||
|
||||
with maybe_load_model(model, mode) as model:
|
||||
compute_one(model, 0, mode, params)
|
||||
|
||||
|
||||
async def test_pipeline():
|
||||
model = 'stabilityai/stable-diffusion-xl-base-1.0'
|
||||
mode = 'txt2img'
|
||||
params = {
|
||||
"prompt": "Kronos God Realistic 4k",
|
||||
"model": model,
|
||||
"step": 21,
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"seed": 168402949,
|
||||
"guidance": "7.5"
|
||||
}
|
||||
|
||||
with maybe_load_model(model, mode) as model:
|
||||
compute_one(model, 0, mode, params)
|
Loading…
Reference in New Issue