mirror of https://github.com/skygpu/skynet.git
commit
79901c85ca
|
@ -1,3 +1,4 @@
|
||||||
|
skynet.ini
|
||||||
.python-version
|
.python-version
|
||||||
hf_home
|
hf_home
|
||||||
outputs
|
outputs
|
||||||
|
|
|
@ -32,3 +32,4 @@ env HF_HOME /hf_home
|
||||||
copy scripts scripts
|
copy scripts scripts
|
||||||
copy tests tests
|
copy tests tests
|
||||||
|
|
||||||
|
expose 40000-45000
|
||||||
|
|
665
LICENSE
665
LICENSE
|
@ -1,11 +1,662 @@
|
||||||
A menos que sea especificamente indicado en el cabezal del archivo, se reservan
|
GNU AFFERO GENERAL PUBLIC LICENSE
|
||||||
todos los derechos sobre este codigo por parte de:
|
Version 3, 19 November 2007
|
||||||
|
|
||||||
Guillermo Rodriguez, guillermor@fing.edu.uy
|
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||||
|
Everyone is permitted to copy and distribute verbatim copies
|
||||||
|
of this license document, but changing it is not allowed.
|
||||||
|
|
||||||
ENGLISH LICENSE:
|
Preamble
|
||||||
|
|
||||||
Unless specifically indicated in the file header, all rights to this code are
|
The GNU Affero General Public License is a free, copyleft license for
|
||||||
reserved by:
|
software and other kinds of works, specifically designed to ensure
|
||||||
|
cooperation with the community in the case of network server software.
|
||||||
|
|
||||||
|
The licenses for most software and other practical works are designed
|
||||||
|
to take away your freedom to share and change the works. By contrast,
|
||||||
|
our General Public Licenses are intended to guarantee your freedom to
|
||||||
|
share and change all versions of a program--to make sure it remains free
|
||||||
|
software for all its users.
|
||||||
|
|
||||||
|
When we speak of free software, we are referring to freedom, not
|
||||||
|
price. Our General Public Licenses are designed to make sure that you
|
||||||
|
have the freedom to distribute copies of free software (and charge for
|
||||||
|
them if you wish), that you receive source code or can get it if you
|
||||||
|
want it, that you can change the software or use pieces of it in new
|
||||||
|
free programs, and that you know you can do these things.
|
||||||
|
|
||||||
|
Developers that use our General Public Licenses protect your rights
|
||||||
|
with two steps: (1) assert copyright on the software, and (2) offer
|
||||||
|
you this License which gives you legal permission to copy, distribute
|
||||||
|
and/or modify the software.
|
||||||
|
|
||||||
|
A secondary benefit of defending all users' freedom is that
|
||||||
|
improvements made in alternate versions of the program, if they
|
||||||
|
receive widespread use, become available for other developers to
|
||||||
|
incorporate. Many developers of free software are heartened and
|
||||||
|
encouraged by the resulting cooperation. However, in the case of
|
||||||
|
software used on network servers, this result may fail to come about.
|
||||||
|
The GNU General Public License permits making a modified version and
|
||||||
|
letting the public access it on a server without ever releasing its
|
||||||
|
source code to the public.
|
||||||
|
|
||||||
|
The GNU Affero General Public License is designed specifically to
|
||||||
|
ensure that, in such cases, the modified source code becomes available
|
||||||
|
to the community. It requires the operator of a network server to
|
||||||
|
provide the source code of the modified version running there to the
|
||||||
|
users of that server. Therefore, public use of a modified version, on
|
||||||
|
a publicly accessible server, gives the public access to the source
|
||||||
|
code of the modified version.
|
||||||
|
|
||||||
|
An older license, called the Affero General Public License and
|
||||||
|
published by Affero, was designed to accomplish similar goals. This is
|
||||||
|
a different license, not a version of the Affero GPL, but Affero has
|
||||||
|
released a new version of the Affero GPL which permits relicensing under
|
||||||
|
this license.
|
||||||
|
|
||||||
|
The precise terms and conditions for copying, distribution and
|
||||||
|
modification follow.
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
0. Definitions.
|
||||||
|
|
||||||
|
"This License" refers to version 3 of the GNU Affero General Public License.
|
||||||
|
|
||||||
|
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||||
|
works, such as semiconductor masks.
|
||||||
|
|
||||||
|
"The Program" refers to any copyrightable work licensed under this
|
||||||
|
License. Each licensee is addressed as "you". "Licensees" and
|
||||||
|
"recipients" may be individuals or organizations.
|
||||||
|
|
||||||
|
To "modify" a work means to copy from or adapt all or part of the work
|
||||||
|
in a fashion requiring copyright permission, other than the making of an
|
||||||
|
exact copy. The resulting work is called a "modified version" of the
|
||||||
|
earlier work or a work "based on" the earlier work.
|
||||||
|
|
||||||
|
A "covered work" means either the unmodified Program or a work based
|
||||||
|
on the Program.
|
||||||
|
|
||||||
|
To "propagate" a work means to do anything with it that, without
|
||||||
|
permission, would make you directly or secondarily liable for
|
||||||
|
infringement under applicable copyright law, except executing it on a
|
||||||
|
computer or modifying a private copy. Propagation includes copying,
|
||||||
|
distribution (with or without modification), making available to the
|
||||||
|
public, and in some countries other activities as well.
|
||||||
|
|
||||||
|
To "convey" a work means any kind of propagation that enables other
|
||||||
|
parties to make or receive copies. Mere interaction with a user through
|
||||||
|
a computer network, with no transfer of a copy, is not conveying.
|
||||||
|
|
||||||
|
An interactive user interface displays "Appropriate Legal Notices"
|
||||||
|
to the extent that it includes a convenient and prominently visible
|
||||||
|
feature that (1) displays an appropriate copyright notice, and (2)
|
||||||
|
tells the user that there is no warranty for the work (except to the
|
||||||
|
extent that warranties are provided), that licensees may convey the
|
||||||
|
work under this License, and how to view a copy of this License. If
|
||||||
|
the interface presents a list of user commands or options, such as a
|
||||||
|
menu, a prominent item in the list meets this criterion.
|
||||||
|
|
||||||
|
1. Source Code.
|
||||||
|
|
||||||
|
The "source code" for a work means the preferred form of the work
|
||||||
|
for making modifications to it. "Object code" means any non-source
|
||||||
|
form of a work.
|
||||||
|
|
||||||
|
A "Standard Interface" means an interface that either is an official
|
||||||
|
standard defined by a recognized standards body, or, in the case of
|
||||||
|
interfaces specified for a particular programming language, one that
|
||||||
|
is widely used among developers working in that language.
|
||||||
|
|
||||||
|
The "System Libraries" of an executable work include anything, other
|
||||||
|
than the work as a whole, that (a) is included in the normal form of
|
||||||
|
packaging a Major Component, but which is not part of that Major
|
||||||
|
Component, and (b) serves only to enable use of the work with that
|
||||||
|
Major Component, or to implement a Standard Interface for which an
|
||||||
|
implementation is available to the public in source code form. A
|
||||||
|
"Major Component", in this context, means a major essential component
|
||||||
|
(kernel, window system, and so on) of the specific operating system
|
||||||
|
(if any) on which the executable work runs, or a compiler used to
|
||||||
|
produce the work, or an object code interpreter used to run it.
|
||||||
|
|
||||||
|
The "Corresponding Source" for a work in object code form means all
|
||||||
|
the source code needed to generate, install, and (for an executable
|
||||||
|
work) run the object code and to modify the work, including scripts to
|
||||||
|
control those activities. However, it does not include the work's
|
||||||
|
System Libraries, or general-purpose tools or generally available free
|
||||||
|
programs which are used unmodified in performing those activities but
|
||||||
|
which are not part of the work. For example, Corresponding Source
|
||||||
|
includes interface definition files associated with source files for
|
||||||
|
the work, and the source code for shared libraries and dynamically
|
||||||
|
linked subprograms that the work is specifically designed to require,
|
||||||
|
such as by intimate data communication or control flow between those
|
||||||
|
subprograms and other parts of the work.
|
||||||
|
|
||||||
|
The Corresponding Source need not include anything that users
|
||||||
|
can regenerate automatically from other parts of the Corresponding
|
||||||
|
Source.
|
||||||
|
|
||||||
|
The Corresponding Source for a work in source code form is that
|
||||||
|
same work.
|
||||||
|
|
||||||
|
2. Basic Permissions.
|
||||||
|
|
||||||
|
All rights granted under this License are granted for the term of
|
||||||
|
copyright on the Program, and are irrevocable provided the stated
|
||||||
|
conditions are met. This License explicitly affirms your unlimited
|
||||||
|
permission to run the unmodified Program. The output from running a
|
||||||
|
covered work is covered by this License only if the output, given its
|
||||||
|
content, constitutes a covered work. This License acknowledges your
|
||||||
|
rights of fair use or other equivalent, as provided by copyright law.
|
||||||
|
|
||||||
|
You may make, run and propagate covered works that you do not
|
||||||
|
convey, without conditions so long as your license otherwise remains
|
||||||
|
in force. You may convey covered works to others for the sole purpose
|
||||||
|
of having them make modifications exclusively for you, or provide you
|
||||||
|
with facilities for running those works, provided that you comply with
|
||||||
|
the terms of this License in conveying all material for which you do
|
||||||
|
not control copyright. Those thus making or running the covered works
|
||||||
|
for you must do so exclusively on your behalf, under your direction
|
||||||
|
and control, on terms that prohibit them from making any copies of
|
||||||
|
your copyrighted material outside their relationship with you.
|
||||||
|
|
||||||
|
Conveying under any other circumstances is permitted solely under
|
||||||
|
the conditions stated below. Sublicensing is not allowed; section 10
|
||||||
|
makes it unnecessary.
|
||||||
|
|
||||||
|
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||||
|
|
||||||
|
No covered work shall be deemed part of an effective technological
|
||||||
|
measure under any applicable law fulfilling obligations under article
|
||||||
|
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||||
|
similar laws prohibiting or restricting circumvention of such
|
||||||
|
measures.
|
||||||
|
|
||||||
|
When you convey a covered work, you waive any legal power to forbid
|
||||||
|
circumvention of technological measures to the extent such circumvention
|
||||||
|
is effected by exercising rights under this License with respect to
|
||||||
|
the covered work, and you disclaim any intention to limit operation or
|
||||||
|
modification of the work as a means of enforcing, against the work's
|
||||||
|
users, your or third parties' legal rights to forbid circumvention of
|
||||||
|
technological measures.
|
||||||
|
|
||||||
|
4. Conveying Verbatim Copies.
|
||||||
|
|
||||||
|
You may convey verbatim copies of the Program's source code as you
|
||||||
|
receive it, in any medium, provided that you conspicuously and
|
||||||
|
appropriately publish on each copy an appropriate copyright notice;
|
||||||
|
keep intact all notices stating that this License and any
|
||||||
|
non-permissive terms added in accord with section 7 apply to the code;
|
||||||
|
keep intact all notices of the absence of any warranty; and give all
|
||||||
|
recipients a copy of this License along with the Program.
|
||||||
|
|
||||||
|
You may charge any price or no price for each copy that you convey,
|
||||||
|
and you may offer support or warranty protection for a fee.
|
||||||
|
|
||||||
|
5. Conveying Modified Source Versions.
|
||||||
|
|
||||||
|
You may convey a work based on the Program, or the modifications to
|
||||||
|
produce it from the Program, in the form of source code under the
|
||||||
|
terms of section 4, provided that you also meet all of these conditions:
|
||||||
|
|
||||||
|
a) The work must carry prominent notices stating that you modified
|
||||||
|
it, and giving a relevant date.
|
||||||
|
|
||||||
|
b) The work must carry prominent notices stating that it is
|
||||||
|
released under this License and any conditions added under section
|
||||||
|
7. This requirement modifies the requirement in section 4 to
|
||||||
|
"keep intact all notices".
|
||||||
|
|
||||||
|
c) You must license the entire work, as a whole, under this
|
||||||
|
License to anyone who comes into possession of a copy. This
|
||||||
|
License will therefore apply, along with any applicable section 7
|
||||||
|
additional terms, to the whole of the work, and all its parts,
|
||||||
|
regardless of how they are packaged. This License gives no
|
||||||
|
permission to license the work in any other way, but it does not
|
||||||
|
invalidate such permission if you have separately received it.
|
||||||
|
|
||||||
|
d) If the work has interactive user interfaces, each must display
|
||||||
|
Appropriate Legal Notices; however, if the Program has interactive
|
||||||
|
interfaces that do not display Appropriate Legal Notices, your
|
||||||
|
work need not make them do so.
|
||||||
|
|
||||||
|
A compilation of a covered work with other separate and independent
|
||||||
|
works, which are not by their nature extensions of the covered work,
|
||||||
|
and which are not combined with it such as to form a larger program,
|
||||||
|
in or on a volume of a storage or distribution medium, is called an
|
||||||
|
"aggregate" if the compilation and its resulting copyright are not
|
||||||
|
used to limit the access or legal rights of the compilation's users
|
||||||
|
beyond what the individual works permit. Inclusion of a covered work
|
||||||
|
in an aggregate does not cause this License to apply to the other
|
||||||
|
parts of the aggregate.
|
||||||
|
|
||||||
|
6. Conveying Non-Source Forms.
|
||||||
|
|
||||||
|
You may convey a covered work in object code form under the terms
|
||||||
|
of sections 4 and 5, provided that you also convey the
|
||||||
|
machine-readable Corresponding Source under the terms of this License,
|
||||||
|
in one of these ways:
|
||||||
|
|
||||||
|
a) Convey the object code in, or embodied in, a physical product
|
||||||
|
(including a physical distribution medium), accompanied by the
|
||||||
|
Corresponding Source fixed on a durable physical medium
|
||||||
|
customarily used for software interchange.
|
||||||
|
|
||||||
|
b) Convey the object code in, or embodied in, a physical product
|
||||||
|
(including a physical distribution medium), accompanied by a
|
||||||
|
written offer, valid for at least three years and valid for as
|
||||||
|
long as you offer spare parts or customer support for that product
|
||||||
|
model, to give anyone who possesses the object code either (1) a
|
||||||
|
copy of the Corresponding Source for all the software in the
|
||||||
|
product that is covered by this License, on a durable physical
|
||||||
|
medium customarily used for software interchange, for a price no
|
||||||
|
more than your reasonable cost of physically performing this
|
||||||
|
conveying of source, or (2) access to copy the
|
||||||
|
Corresponding Source from a network server at no charge.
|
||||||
|
|
||||||
|
c) Convey individual copies of the object code with a copy of the
|
||||||
|
written offer to provide the Corresponding Source. This
|
||||||
|
alternative is allowed only occasionally and noncommercially, and
|
||||||
|
only if you received the object code with such an offer, in accord
|
||||||
|
with subsection 6b.
|
||||||
|
|
||||||
|
d) Convey the object code by offering access from a designated
|
||||||
|
place (gratis or for a charge), and offer equivalent access to the
|
||||||
|
Corresponding Source in the same way through the same place at no
|
||||||
|
further charge. You need not require recipients to copy the
|
||||||
|
Corresponding Source along with the object code. If the place to
|
||||||
|
copy the object code is a network server, the Corresponding Source
|
||||||
|
may be on a different server (operated by you or a third party)
|
||||||
|
that supports equivalent copying facilities, provided you maintain
|
||||||
|
clear directions next to the object code saying where to find the
|
||||||
|
Corresponding Source. Regardless of what server hosts the
|
||||||
|
Corresponding Source, you remain obligated to ensure that it is
|
||||||
|
available for as long as needed to satisfy these requirements.
|
||||||
|
|
||||||
|
e) Convey the object code using peer-to-peer transmission, provided
|
||||||
|
you inform other peers where the object code and Corresponding
|
||||||
|
Source of the work are being offered to the general public at no
|
||||||
|
charge under subsection 6d.
|
||||||
|
|
||||||
|
A separable portion of the object code, whose source code is excluded
|
||||||
|
from the Corresponding Source as a System Library, need not be
|
||||||
|
included in conveying the object code work.
|
||||||
|
|
||||||
|
A "User Product" is either (1) a "consumer product", which means any
|
||||||
|
tangible personal property which is normally used for personal, family,
|
||||||
|
or household purposes, or (2) anything designed or sold for incorporation
|
||||||
|
into a dwelling. In determining whether a product is a consumer product,
|
||||||
|
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||||
|
product received by a particular user, "normally used" refers to a
|
||||||
|
typical or common use of that class of product, regardless of the status
|
||||||
|
of the particular user or of the way in which the particular user
|
||||||
|
actually uses, or expects or is expected to use, the product. A product
|
||||||
|
is a consumer product regardless of whether the product has substantial
|
||||||
|
commercial, industrial or non-consumer uses, unless such uses represent
|
||||||
|
the only significant mode of use of the product.
|
||||||
|
|
||||||
|
"Installation Information" for a User Product means any methods,
|
||||||
|
procedures, authorization keys, or other information required to install
|
||||||
|
and execute modified versions of a covered work in that User Product from
|
||||||
|
a modified version of its Corresponding Source. The information must
|
||||||
|
suffice to ensure that the continued functioning of the modified object
|
||||||
|
code is in no case prevented or interfered with solely because
|
||||||
|
modification has been made.
|
||||||
|
|
||||||
|
If you convey an object code work under this section in, or with, or
|
||||||
|
specifically for use in, a User Product, and the conveying occurs as
|
||||||
|
part of a transaction in which the right of possession and use of the
|
||||||
|
User Product is transferred to the recipient in perpetuity or for a
|
||||||
|
fixed term (regardless of how the transaction is characterized), the
|
||||||
|
Corresponding Source conveyed under this section must be accompanied
|
||||||
|
by the Installation Information. But this requirement does not apply
|
||||||
|
if neither you nor any third party retains the ability to install
|
||||||
|
modified object code on the User Product (for example, the work has
|
||||||
|
been installed in ROM).
|
||||||
|
|
||||||
|
The requirement to provide Installation Information does not include a
|
||||||
|
requirement to continue to provide support service, warranty, or updates
|
||||||
|
for a work that has been modified or installed by the recipient, or for
|
||||||
|
the User Product in which it has been modified or installed. Access to a
|
||||||
|
network may be denied when the modification itself materially and
|
||||||
|
adversely affects the operation of the network or violates the rules and
|
||||||
|
protocols for communication across the network.
|
||||||
|
|
||||||
|
Corresponding Source conveyed, and Installation Information provided,
|
||||||
|
in accord with this section must be in a format that is publicly
|
||||||
|
documented (and with an implementation available to the public in
|
||||||
|
source code form), and must require no special password or key for
|
||||||
|
unpacking, reading or copying.
|
||||||
|
|
||||||
|
7. Additional Terms.
|
||||||
|
|
||||||
|
"Additional permissions" are terms that supplement the terms of this
|
||||||
|
License by making exceptions from one or more of its conditions.
|
||||||
|
Additional permissions that are applicable to the entire Program shall
|
||||||
|
be treated as though they were included in this License, to the extent
|
||||||
|
that they are valid under applicable law. If additional permissions
|
||||||
|
apply only to part of the Program, that part may be used separately
|
||||||
|
under those permissions, but the entire Program remains governed by
|
||||||
|
this License without regard to the additional permissions.
|
||||||
|
|
||||||
|
When you convey a copy of a covered work, you may at your option
|
||||||
|
remove any additional permissions from that copy, or from any part of
|
||||||
|
it. (Additional permissions may be written to require their own
|
||||||
|
removal in certain cases when you modify the work.) You may place
|
||||||
|
additional permissions on material, added by you to a covered work,
|
||||||
|
for which you have or can give appropriate copyright permission.
|
||||||
|
|
||||||
|
Notwithstanding any other provision of this License, for material you
|
||||||
|
add to a covered work, you may (if authorized by the copyright holders of
|
||||||
|
that material) supplement the terms of this License with terms:
|
||||||
|
|
||||||
|
a) Disclaiming warranty or limiting liability differently from the
|
||||||
|
terms of sections 15 and 16 of this License; or
|
||||||
|
|
||||||
|
b) Requiring preservation of specified reasonable legal notices or
|
||||||
|
author attributions in that material or in the Appropriate Legal
|
||||||
|
Notices displayed by works containing it; or
|
||||||
|
|
||||||
|
c) Prohibiting misrepresentation of the origin of that material, or
|
||||||
|
requiring that modified versions of such material be marked in
|
||||||
|
reasonable ways as different from the original version; or
|
||||||
|
|
||||||
|
d) Limiting the use for publicity purposes of names of licensors or
|
||||||
|
authors of the material; or
|
||||||
|
|
||||||
|
e) Declining to grant rights under trademark law for use of some
|
||||||
|
trade names, trademarks, or service marks; or
|
||||||
|
|
||||||
|
f) Requiring indemnification of licensors and authors of that
|
||||||
|
material by anyone who conveys the material (or modified versions of
|
||||||
|
it) with contractual assumptions of liability to the recipient, for
|
||||||
|
any liability that these contractual assumptions directly impose on
|
||||||
|
those licensors and authors.
|
||||||
|
|
||||||
|
All other non-permissive additional terms are considered "further
|
||||||
|
restrictions" within the meaning of section 10. If the Program as you
|
||||||
|
received it, or any part of it, contains a notice stating that it is
|
||||||
|
governed by this License along with a term that is a further
|
||||||
|
restriction, you may remove that term. If a license document contains
|
||||||
|
a further restriction but permits relicensing or conveying under this
|
||||||
|
License, you may add to a covered work material governed by the terms
|
||||||
|
of that license document, provided that the further restriction does
|
||||||
|
not survive such relicensing or conveying.
|
||||||
|
|
||||||
|
If you add terms to a covered work in accord with this section, you
|
||||||
|
must place, in the relevant source files, a statement of the
|
||||||
|
additional terms that apply to those files, or a notice indicating
|
||||||
|
where to find the applicable terms.
|
||||||
|
|
||||||
|
Additional terms, permissive or non-permissive, may be stated in the
|
||||||
|
form of a separately written license, or stated as exceptions;
|
||||||
|
the above requirements apply either way.
|
||||||
|
|
||||||
|
8. Termination.
|
||||||
|
|
||||||
|
You may not propagate or modify a covered work except as expressly
|
||||||
|
provided under this License. Any attempt otherwise to propagate or
|
||||||
|
modify it is void, and will automatically terminate your rights under
|
||||||
|
this License (including any patent licenses granted under the third
|
||||||
|
paragraph of section 11).
|
||||||
|
|
||||||
|
However, if you cease all violation of this License, then your
|
||||||
|
license from a particular copyright holder is reinstated (a)
|
||||||
|
provisionally, unless and until the copyright holder explicitly and
|
||||||
|
finally terminates your license, and (b) permanently, if the copyright
|
||||||
|
holder fails to notify you of the violation by some reasonable means
|
||||||
|
prior to 60 days after the cessation.
|
||||||
|
|
||||||
|
Moreover, your license from a particular copyright holder is
|
||||||
|
reinstated permanently if the copyright holder notifies you of the
|
||||||
|
violation by some reasonable means, this is the first time you have
|
||||||
|
received notice of violation of this License (for any work) from that
|
||||||
|
copyright holder, and you cure the violation prior to 30 days after
|
||||||
|
your receipt of the notice.
|
||||||
|
|
||||||
|
Termination of your rights under this section does not terminate the
|
||||||
|
licenses of parties who have received copies or rights from you under
|
||||||
|
this License. If your rights have been terminated and not permanently
|
||||||
|
reinstated, you do not qualify to receive new licenses for the same
|
||||||
|
material under section 10.
|
||||||
|
|
||||||
|
9. Acceptance Not Required for Having Copies.
|
||||||
|
|
||||||
|
You are not required to accept this License in order to receive or
|
||||||
|
run a copy of the Program. Ancillary propagation of a covered work
|
||||||
|
occurring solely as a consequence of using peer-to-peer transmission
|
||||||
|
to receive a copy likewise does not require acceptance. However,
|
||||||
|
nothing other than this License grants you permission to propagate or
|
||||||
|
modify any covered work. These actions infringe copyright if you do
|
||||||
|
not accept this License. Therefore, by modifying or propagating a
|
||||||
|
covered work, you indicate your acceptance of this License to do so.
|
||||||
|
|
||||||
|
10. Automatic Licensing of Downstream Recipients.
|
||||||
|
|
||||||
|
Each time you convey a covered work, the recipient automatically
|
||||||
|
receives a license from the original licensors, to run, modify and
|
||||||
|
propagate that work, subject to this License. You are not responsible
|
||||||
|
for enforcing compliance by third parties with this License.
|
||||||
|
|
||||||
|
An "entity transaction" is a transaction transferring control of an
|
||||||
|
organization, or substantially all assets of one, or subdividing an
|
||||||
|
organization, or merging organizations. If propagation of a covered
|
||||||
|
work results from an entity transaction, each party to that
|
||||||
|
transaction who receives a copy of the work also receives whatever
|
||||||
|
licenses to the work the party's predecessor in interest had or could
|
||||||
|
give under the previous paragraph, plus a right to possession of the
|
||||||
|
Corresponding Source of the work from the predecessor in interest, if
|
||||||
|
the predecessor has it or can get it with reasonable efforts.
|
||||||
|
|
||||||
|
You may not impose any further restrictions on the exercise of the
|
||||||
|
rights granted or affirmed under this License. For example, you may
|
||||||
|
not impose a license fee, royalty, or other charge for exercise of
|
||||||
|
rights granted under this License, and you may not initiate litigation
|
||||||
|
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||||
|
any patent claim is infringed by making, using, selling, offering for
|
||||||
|
sale, or importing the Program or any portion of it.
|
||||||
|
|
||||||
|
11. Patents.
|
||||||
|
|
||||||
|
A "contributor" is a copyright holder who authorizes use under this
|
||||||
|
License of the Program or a work on which the Program is based. The
|
||||||
|
work thus licensed is called the contributor's "contributor version".
|
||||||
|
|
||||||
|
A contributor's "essential patent claims" are all patent claims
|
||||||
|
owned or controlled by the contributor, whether already acquired or
|
||||||
|
hereafter acquired, that would be infringed by some manner, permitted
|
||||||
|
by this License, of making, using, or selling its contributor version,
|
||||||
|
but do not include claims that would be infringed only as a
|
||||||
|
consequence of further modification of the contributor version. For
|
||||||
|
purposes of this definition, "control" includes the right to grant
|
||||||
|
patent sublicenses in a manner consistent with the requirements of
|
||||||
|
this License.
|
||||||
|
|
||||||
|
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||||
|
patent license under the contributor's essential patent claims, to
|
||||||
|
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||||
|
propagate the contents of its contributor version.
|
||||||
|
|
||||||
|
In the following three paragraphs, a "patent license" is any express
|
||||||
|
agreement or commitment, however denominated, not to enforce a patent
|
||||||
|
(such as an express permission to practice a patent or covenant not to
|
||||||
|
sue for patent infringement). To "grant" such a patent license to a
|
||||||
|
party means to make such an agreement or commitment not to enforce a
|
||||||
|
patent against the party.
|
||||||
|
|
||||||
|
If you convey a covered work, knowingly relying on a patent license,
|
||||||
|
and the Corresponding Source of the work is not available for anyone
|
||||||
|
to copy, free of charge and under the terms of this License, through a
|
||||||
|
publicly available network server or other readily accessible means,
|
||||||
|
then you must either (1) cause the Corresponding Source to be so
|
||||||
|
available, or (2) arrange to deprive yourself of the benefit of the
|
||||||
|
patent license for this particular work, or (3) arrange, in a manner
|
||||||
|
consistent with the requirements of this License, to extend the patent
|
||||||
|
license to downstream recipients. "Knowingly relying" means you have
|
||||||
|
actual knowledge that, but for the patent license, your conveying the
|
||||||
|
covered work in a country, or your recipient's use of the covered work
|
||||||
|
in a country, would infringe one or more identifiable patents in that
|
||||||
|
country that you have reason to believe are valid.
|
||||||
|
|
||||||
|
If, pursuant to or in connection with a single transaction or
|
||||||
|
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||||
|
covered work, and grant a patent license to some of the parties
|
||||||
|
receiving the covered work authorizing them to use, propagate, modify
|
||||||
|
or convey a specific copy of the covered work, then the patent license
|
||||||
|
you grant is automatically extended to all recipients of the covered
|
||||||
|
work and works based on it.
|
||||||
|
|
||||||
|
A patent license is "discriminatory" if it does not include within
|
||||||
|
the scope of its coverage, prohibits the exercise of, or is
|
||||||
|
conditioned on the non-exercise of one or more of the rights that are
|
||||||
|
specifically granted under this License. You may not convey a covered
|
||||||
|
work if you are a party to an arrangement with a third party that is
|
||||||
|
in the business of distributing software, under which you make payment
|
||||||
|
to the third party based on the extent of your activity of conveying
|
||||||
|
the work, and under which the third party grants, to any of the
|
||||||
|
parties who would receive the covered work from you, a discriminatory
|
||||||
|
patent license (a) in connection with copies of the covered work
|
||||||
|
conveyed by you (or copies made from those copies), or (b) primarily
|
||||||
|
for and in connection with specific products or compilations that
|
||||||
|
contain the covered work, unless you entered into that arrangement,
|
||||||
|
or that patent license was granted, prior to 28 March 2007.
|
||||||
|
|
||||||
|
Nothing in this License shall be construed as excluding or limiting
|
||||||
|
any implied license or other defenses to infringement that may
|
||||||
|
otherwise be available to you under applicable patent law.
|
||||||
|
|
||||||
|
12. No Surrender of Others' Freedom.
|
||||||
|
|
||||||
|
If conditions are imposed on you (whether by court order, agreement or
|
||||||
|
otherwise) that contradict the conditions of this License, they do not
|
||||||
|
excuse you from the conditions of this License. If you cannot convey a
|
||||||
|
covered work so as to satisfy simultaneously your obligations under this
|
||||||
|
License and any other pertinent obligations, then as a consequence you may
|
||||||
|
not convey it at all. For example, if you agree to terms that obligate you
|
||||||
|
to collect a royalty for further conveying from those to whom you convey
|
||||||
|
the Program, the only way you could satisfy both those terms and this
|
||||||
|
License would be to refrain entirely from conveying the Program.
|
||||||
|
|
||||||
|
13. Remote Network Interaction; Use with the GNU General Public License.
|
||||||
|
|
||||||
|
Notwithstanding any other provision of this License, if you modify the
|
||||||
|
Program, your modified version must prominently offer all users
|
||||||
|
interacting with it remotely through a computer network (if your version
|
||||||
|
supports such interaction) an opportunity to receive the Corresponding
|
||||||
|
Source of your version by providing access to the Corresponding Source
|
||||||
|
from a network server at no charge, through some standard or customary
|
||||||
|
means of facilitating copying of software. This Corresponding Source
|
||||||
|
shall include the Corresponding Source for any work covered by version 3
|
||||||
|
of the GNU General Public License that is incorporated pursuant to the
|
||||||
|
following paragraph.
|
||||||
|
|
||||||
|
Notwithstanding any other provision of this License, you have
|
||||||
|
permission to link or combine any covered work with a work licensed
|
||||||
|
under version 3 of the GNU General Public License into a single
|
||||||
|
combined work, and to convey the resulting work. The terms of this
|
||||||
|
License will continue to apply to the part which is the covered work,
|
||||||
|
but the work with which it is combined will remain governed by version
|
||||||
|
3 of the GNU General Public License.
|
||||||
|
|
||||||
|
14. Revised Versions of this License.
|
||||||
|
|
||||||
|
The Free Software Foundation may publish revised and/or new versions of
|
||||||
|
the GNU Affero General Public License from time to time. Such new versions
|
||||||
|
will be similar in spirit to the present version, but may differ in detail to
|
||||||
|
address new problems or concerns.
|
||||||
|
|
||||||
|
Each version is given a distinguishing version number. If the
|
||||||
|
Program specifies that a certain numbered version of the GNU Affero General
|
||||||
|
Public License "or any later version" applies to it, you have the
|
||||||
|
option of following the terms and conditions either of that numbered
|
||||||
|
version or of any later version published by the Free Software
|
||||||
|
Foundation. If the Program does not specify a version number of the
|
||||||
|
GNU Affero General Public License, you may choose any version ever published
|
||||||
|
by the Free Software Foundation.
|
||||||
|
|
||||||
|
If the Program specifies that a proxy can decide which future
|
||||||
|
versions of the GNU Affero General Public License can be used, that proxy's
|
||||||
|
public statement of acceptance of a version permanently authorizes you
|
||||||
|
to choose that version for the Program.
|
||||||
|
|
||||||
|
Later license versions may give you additional or different
|
||||||
|
permissions. However, no additional obligations are imposed on any
|
||||||
|
author or copyright holder as a result of your choosing to follow a
|
||||||
|
later version.
|
||||||
|
|
||||||
|
15. Disclaimer of Warranty.
|
||||||
|
|
||||||
|
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||||
|
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||||
|
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||||
|
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||||
|
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||||
|
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||||
|
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||||
|
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||||
|
|
||||||
|
16. Limitation of Liability.
|
||||||
|
|
||||||
|
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||||
|
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||||
|
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||||
|
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||||
|
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||||
|
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||||
|
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||||
|
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||||
|
SUCH DAMAGES.
|
||||||
|
|
||||||
|
17. Interpretation of Sections 15 and 16.
|
||||||
|
|
||||||
|
If the disclaimer of warranty and limitation of liability provided
|
||||||
|
above cannot be given local legal effect according to their terms,
|
||||||
|
reviewing courts shall apply local law that most closely approximates
|
||||||
|
an absolute waiver of all civil liability in connection with the
|
||||||
|
Program, unless a warranty or assumption of liability accompanies a
|
||||||
|
copy of the Program in return for a fee.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
How to Apply These Terms to Your New Programs
|
||||||
|
|
||||||
|
If you develop a new program, and you want it to be of the greatest
|
||||||
|
possible use to the public, the best way to achieve this is to make it
|
||||||
|
free software which everyone can redistribute and change under these terms.
|
||||||
|
|
||||||
|
To do so, attach the following notices to the program. It is safest
|
||||||
|
to attach them to the start of each source file to most effectively
|
||||||
|
state the exclusion of warranty; and each file should have at least
|
||||||
|
the "copyright" line and a pointer to where the full notice is found.
|
||||||
|
|
||||||
|
<one line to give the program's name and a brief idea of what it does.>
|
||||||
|
Copyright (C) <year> <name of author>
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published
|
||||||
|
by the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
Also add information on how to contact you by electronic and paper mail.
|
||||||
|
|
||||||
|
If your software can interact with users remotely through a computer
|
||||||
|
network, you should also make sure that it provides a way for users to
|
||||||
|
get its source. For example, if your program is a web application, its
|
||||||
|
interface could display a "Source" link that leads users to an archive
|
||||||
|
of the code. There are many ways you could offer source, and different
|
||||||
|
solutions will be better for different programs; see section 13 for the
|
||||||
|
specific requirements.
|
||||||
|
|
||||||
|
You should also get your employer (if you work as a programmer) or school,
|
||||||
|
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||||
|
For more information on this, and how to apply and follow the GNU AGPL, see
|
||||||
|
<https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
Guillermo Rodriguez, guillermor@.edu.uy
|
|
||||||
|
|
|
@ -3,4 +3,4 @@ pytest
|
||||||
pytest-trio
|
pytest-trio
|
||||||
psycopg2-binary
|
psycopg2-binary
|
||||||
|
|
||||||
git+https://github.com/guilledk/pytest-dockerctl.git@host_network#egg=pytest-dockerctl
|
git+https://github.com/guilledk/pytest-dockerctl.git@multi_names#egg=pytest-dockerctl
|
||||||
|
|
|
@ -9,3 +9,5 @@ protobuf
|
||||||
pyOpenSSL
|
pyOpenSSL
|
||||||
trio_asyncio
|
trio_asyncio
|
||||||
pyTelegramBotAPI
|
pyTelegramBotAPI
|
||||||
|
|
||||||
|
git+https://github.com/goodboy/tractor.git@master#egg=tractor
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
[skynet]
|
||||||
|
certs_dir = certs
|
||||||
|
|
||||||
|
[skynet.dgpu]
|
||||||
|
hf_home = hf_home
|
||||||
|
hf_token = hf_XxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXx
|
||||||
|
|
||||||
|
[skynet.telegram]
|
||||||
|
token = XXXXXXXXXX:xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||||
|
|
||||||
|
[skynet.telegram-test]
|
||||||
|
token = XXXXXXXXXX:xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
504
skynet/brain.py
504
skynet/brain.py
|
@ -1,35 +1,24 @@
|
||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
import uuid
|
|
||||||
import zlib
|
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
|
||||||
|
|
||||||
from uuid import UUID
|
|
||||||
from pathlib import Path
|
|
||||||
from functools import partial
|
|
||||||
from contextlib import asynccontextmanager as acm
|
from contextlib import asynccontextmanager as acm
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
import pynng
|
|
||||||
import trio_asyncio
|
|
||||||
|
|
||||||
from pynng import TLSConfig
|
from pynng import Context
|
||||||
from OpenSSL.crypto import (
|
|
||||||
load_privatekey,
|
|
||||||
load_certificate,
|
|
||||||
FILETYPE_PEM
|
|
||||||
)
|
|
||||||
|
|
||||||
from .db import *
|
from .utils import time_ms
|
||||||
|
from .network import *
|
||||||
|
from .protobuf import *
|
||||||
from .constants import *
|
from .constants import *
|
||||||
|
|
||||||
from .protobuf import *
|
|
||||||
|
|
||||||
|
|
||||||
|
class SkynetRPCBadRequest(BaseException):
|
||||||
|
...
|
||||||
|
|
||||||
class SkynetDGPUOffline(BaseException):
|
class SkynetDGPUOffline(BaseException):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -44,39 +33,71 @@ class SkynetShutdownRequested(BaseException):
|
||||||
|
|
||||||
|
|
||||||
@acm
|
@acm
|
||||||
async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
async def run_skynet(
|
||||||
|
rpc_address: str = DEFAULT_RPC_ADDR
|
||||||
|
):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logging.info('skynet is starting')
|
||||||
|
|
||||||
nodes = OrderedDict()
|
nodes = OrderedDict()
|
||||||
wip_reqs = {}
|
|
||||||
fin_reqs = {}
|
|
||||||
heartbeats = {}
|
heartbeats = {}
|
||||||
next_worker: Optional[int] = None
|
next_worker: Optional[int] = None
|
||||||
security = len(tls_whitelist) > 0
|
|
||||||
|
|
||||||
def connect_node(uid):
|
def connect_node(req: SkynetRPCRequest):
|
||||||
nonlocal next_worker
|
nonlocal next_worker
|
||||||
nodes[uid] = {
|
|
||||||
'task': None
|
|
||||||
}
|
|
||||||
logging.info(f'dgpu online: {uid}')
|
|
||||||
|
|
||||||
if not next_worker:
|
node_params = MessageToDict(req.params)
|
||||||
next_worker = 0
|
logging.info(f'got node params {node_params}')
|
||||||
|
|
||||||
|
if 'dgpu_addr' not in node_params:
|
||||||
|
raise SkynetRPCBadRequest(
|
||||||
|
f'DGPU connection params don\'t include dgpu addr')
|
||||||
|
|
||||||
|
session = SessionClient(
|
||||||
|
node_params['dgpu_addr'],
|
||||||
|
'skynet',
|
||||||
|
cert_name='brain.cert',
|
||||||
|
key_name='brain.key',
|
||||||
|
ca_name=node_params['cert']
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
session.connect()
|
||||||
|
|
||||||
|
node = {
|
||||||
|
'task': None,
|
||||||
|
'session': session
|
||||||
|
}
|
||||||
|
node.update(node_params)
|
||||||
|
|
||||||
|
nodes[req.uid] = node
|
||||||
|
logging.info(f'DGPU node online: {req.uid}')
|
||||||
|
|
||||||
|
if not next_worker:
|
||||||
|
next_worker = 0
|
||||||
|
|
||||||
|
except pynng.exceptions.ConnectionRefused:
|
||||||
|
logging.warning(f'error while dialing dgpu node... dropping...')
|
||||||
|
raise SkynetDGPUOffline('Connection to dgpu node addr failed.')
|
||||||
|
|
||||||
def disconnect_node(uid):
|
def disconnect_node(uid):
|
||||||
nonlocal next_worker
|
nonlocal next_worker
|
||||||
if uid not in nodes:
|
if uid not in nodes:
|
||||||
|
logging.warning(f'Attempt to disconnect unknown node {uid}')
|
||||||
return
|
return
|
||||||
|
|
||||||
i = list(nodes.keys()).index(uid)
|
i = list(nodes.keys()).index(uid)
|
||||||
|
nodes[uid]['session'].disconnect()
|
||||||
del nodes[uid]
|
del nodes[uid]
|
||||||
|
|
||||||
if i < next_worker:
|
if i < next_worker:
|
||||||
next_worker -= 1
|
next_worker -= 1
|
||||||
|
|
||||||
|
logging.warning(f'DGPU node offline: {uid}')
|
||||||
|
|
||||||
if len(nodes) == 0:
|
if len(nodes) == 0:
|
||||||
logging.info('nw: None')
|
logging.info('All nodes disconnected.')
|
||||||
next_worker = None
|
next_worker = None
|
||||||
|
|
||||||
logging.warning(f'dgpu offline: {uid}')
|
|
||||||
|
|
||||||
def is_worker_busy(nid: str):
|
def is_worker_busy(nid: str):
|
||||||
return nodes[nid]['task'] != None
|
return nodes[nid]['task'] != None
|
||||||
|
@ -90,8 +111,6 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
|
|
||||||
def get_next_worker():
|
def get_next_worker():
|
||||||
nonlocal next_worker
|
nonlocal next_worker
|
||||||
logging.info('get next_worker called')
|
|
||||||
logging.info(f'pre next_worker: {next_worker}')
|
|
||||||
|
|
||||||
if next_worker == None:
|
if next_worker == None:
|
||||||
raise SkynetDGPUOffline('No workers connected, try again later')
|
raise SkynetDGPUOffline('No workers connected, try again later')
|
||||||
|
@ -113,392 +132,79 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
if next_worker >= len(nodes):
|
if next_worker >= len(nodes):
|
||||||
next_worker = 0
|
next_worker = 0
|
||||||
|
|
||||||
logging.info(f'post next_worker: {next_worker}')
|
|
||||||
|
|
||||||
return nid
|
return nid
|
||||||
|
|
||||||
async def dgpu_heartbeat_service():
|
async def rpc_handler(req: SkynetRPCRequest, ctx: Context):
|
||||||
nonlocal heartbeats
|
result = {'ok': {}}
|
||||||
while True:
|
|
||||||
await trio.sleep(60)
|
|
||||||
rid = uuid.uuid4().hex
|
|
||||||
beat_msg = DGPUBusMessage(
|
|
||||||
rid=rid,
|
|
||||||
nid='',
|
|
||||||
method='heartbeat'
|
|
||||||
)
|
|
||||||
heartbeats.clear()
|
|
||||||
heartbeats[rid] = int(time.time() * 1000)
|
|
||||||
await dgpu_bus.asend(beat_msg.SerializeToString())
|
|
||||||
logging.info('sent heartbeat')
|
|
||||||
|
|
||||||
async def dgpu_bus_streamer():
|
|
||||||
nonlocal wip_reqs, fin_reqs, heartbeats
|
|
||||||
while True:
|
|
||||||
raw_msg = await dgpu_bus.arecv()
|
|
||||||
logging.info(f'streamer got {len(raw_msg)} bytes.')
|
|
||||||
msg = DGPUBusMessage()
|
|
||||||
msg.ParseFromString(raw_msg)
|
|
||||||
|
|
||||||
if security:
|
|
||||||
verify_protobuf_msg(msg, tls_whitelist[msg.auth.cert])
|
|
||||||
|
|
||||||
rid = msg.rid
|
|
||||||
|
|
||||||
if msg.method == 'heartbeat':
|
|
||||||
sent_time = heartbeats[rid]
|
|
||||||
delta = msg.params['time'] - sent_time
|
|
||||||
logging.info(f'got heartbeat reply from {msg.nid}, ping: {delta}')
|
|
||||||
continue
|
|
||||||
|
|
||||||
if rid not in wip_reqs:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if msg.method == 'binary-reply':
|
|
||||||
logging.info('bin reply, recv extra data')
|
|
||||||
raw_img = await dgpu_bus.arecv()
|
|
||||||
msg = (msg, raw_img)
|
|
||||||
|
|
||||||
fin_reqs[rid] = msg
|
|
||||||
event = wip_reqs[rid]
|
|
||||||
event.set()
|
|
||||||
del wip_reqs[rid]
|
|
||||||
|
|
||||||
async def dgpu_stream_one_img(req: DiffusionParameters, img_buf=None):
|
|
||||||
nonlocal wip_reqs, fin_reqs, next_worker
|
|
||||||
nid = get_next_worker()
|
|
||||||
idx = list(nodes.keys()).index(nid)
|
|
||||||
logging.info(f'dgpu_stream_one_img {idx}/{len(nodes)} {nid}')
|
|
||||||
rid = uuid.uuid4().hex
|
|
||||||
ack_event = trio.Event()
|
|
||||||
img_event = trio.Event()
|
|
||||||
wip_reqs[rid] = ack_event
|
|
||||||
|
|
||||||
nodes[nid]['task'] = rid
|
|
||||||
|
|
||||||
dgpu_req = DGPUBusMessage(
|
|
||||||
rid=rid,
|
|
||||||
nid=nid,
|
|
||||||
method='diffuse')
|
|
||||||
dgpu_req.params.update(req.to_dict())
|
|
||||||
|
|
||||||
if security:
|
|
||||||
dgpu_req.auth.cert = 'skynet'
|
|
||||||
dgpu_req.auth.sig = sign_protobuf_msg(dgpu_req, tls_key)
|
|
||||||
|
|
||||||
msg = dgpu_req.SerializeToString()
|
|
||||||
if img_buf:
|
|
||||||
logging.info(f'sending img of size {len(img_buf)} as attachment')
|
|
||||||
logging.info(img_buf[:10])
|
|
||||||
msg = f'BINEXT%$%$'.encode() + msg + b'%$%$' + img_buf
|
|
||||||
|
|
||||||
await dgpu_bus.asend(msg)
|
|
||||||
|
|
||||||
with trio.move_on_after(4):
|
|
||||||
await ack_event.wait()
|
|
||||||
|
|
||||||
logging.info(f'ack event: {ack_event.is_set()}')
|
|
||||||
|
|
||||||
if not ack_event.is_set():
|
|
||||||
disconnect_node(nid)
|
|
||||||
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
|
|
||||||
|
|
||||||
ack_msg = fin_reqs[rid]
|
|
||||||
if 'ack' not in ack_msg.params:
|
|
||||||
disconnect_node(nid)
|
|
||||||
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
|
|
||||||
|
|
||||||
wip_reqs[rid] = img_event
|
|
||||||
with trio.move_on_after(30):
|
|
||||||
await img_event.wait()
|
|
||||||
|
|
||||||
logging.info(f'img event: {ack_event.is_set()}')
|
|
||||||
|
|
||||||
if not img_event.is_set():
|
|
||||||
disconnect_node(nid)
|
|
||||||
raise SkynetDGPUComputeError('30 seconds timeout while processing request')
|
|
||||||
|
|
||||||
nodes[nid]['task'] = None
|
|
||||||
|
|
||||||
resp = fin_reqs[rid]
|
|
||||||
del fin_reqs[rid]
|
|
||||||
if isinstance(resp, tuple):
|
|
||||||
meta, img = resp
|
|
||||||
return rid, img, meta.params
|
|
||||||
|
|
||||||
raise SkynetDGPUComputeError(MessageToDict(resp.params))
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_user_request(rpc_ctx, req):
|
|
||||||
try:
|
|
||||||
async with db_pool.acquire() as conn:
|
|
||||||
user = await get_or_create_user(conn, req.uid)
|
|
||||||
|
|
||||||
result = {}
|
|
||||||
|
|
||||||
match req.method:
|
|
||||||
case 'txt2img':
|
|
||||||
logging.info('txt2img')
|
|
||||||
user_config = {**(await get_user_config(conn, user))}
|
|
||||||
del user_config['id']
|
|
||||||
user_config.update(MessageToDict(req.params))
|
|
||||||
|
|
||||||
req = DiffusionParameters(**user_config, image=False)
|
|
||||||
rid, img, meta = await dgpu_stream_one_img(req)
|
|
||||||
logging.info(f'done streaming {rid}')
|
|
||||||
result = {
|
|
||||||
'id': rid,
|
|
||||||
'img': img.hex(),
|
|
||||||
'meta': meta
|
|
||||||
}
|
|
||||||
|
|
||||||
await update_user_stats(conn, user, last_prompt=user_config['prompt'])
|
|
||||||
logging.info('updated user stats.')
|
|
||||||
|
|
||||||
case 'img2img':
|
|
||||||
logging.info('img2img')
|
|
||||||
user_config = {**(await get_user_config(conn, user))}
|
|
||||||
del user_config['id']
|
|
||||||
|
|
||||||
params = MessageToDict(req.params)
|
|
||||||
img_buf = bytes.fromhex(params['img'])
|
|
||||||
del params['img']
|
|
||||||
user_config.update(params)
|
|
||||||
|
|
||||||
req = DiffusionParameters(**user_config, image=True)
|
|
||||||
|
|
||||||
if not req.image:
|
|
||||||
raise AssertionError('Didn\'t enable image flag for img2img?')
|
|
||||||
|
|
||||||
rid, img, meta = await dgpu_stream_one_img(req, img_buf=img_buf)
|
|
||||||
logging.info(f'done streaming {rid}')
|
|
||||||
result = {
|
|
||||||
'id': rid,
|
|
||||||
'img': img.hex(),
|
|
||||||
'meta': meta
|
|
||||||
}
|
|
||||||
|
|
||||||
await update_user_stats(conn, user, last_prompt=user_config['prompt'])
|
|
||||||
logging.info('updated user stats.')
|
|
||||||
|
|
||||||
case 'redo':
|
|
||||||
logging.info('redo')
|
|
||||||
user_config = {**(await get_user_config(conn, user))}
|
|
||||||
del user_config['id']
|
|
||||||
prompt = await get_last_prompt_of(conn, user)
|
|
||||||
|
|
||||||
if prompt:
|
|
||||||
req = DiffusionParameters(
|
|
||||||
prompt=prompt,
|
|
||||||
**user_config,
|
|
||||||
image=False
|
|
||||||
)
|
|
||||||
rid, img, meta = await dgpu_stream_one_img(req)
|
|
||||||
result = {
|
|
||||||
'id': rid,
|
|
||||||
'img': img.hex(),
|
|
||||||
'meta': meta
|
|
||||||
}
|
|
||||||
await update_user_stats(conn, user)
|
|
||||||
logging.info('updated user stats.')
|
|
||||||
|
|
||||||
else:
|
|
||||||
result = {
|
|
||||||
'error': 'skynet_no_last_prompt',
|
|
||||||
'message': 'No prompt to redo, do txt2img first'
|
|
||||||
}
|
|
||||||
|
|
||||||
case 'config':
|
|
||||||
logging.info('config')
|
|
||||||
if req.params['attr'] in CONFIG_ATTRS:
|
|
||||||
logging.info(f'update: {req.params}')
|
|
||||||
await update_user_config(
|
|
||||||
conn, user, req.params['attr'], req.params['val'])
|
|
||||||
logging.info('done')
|
|
||||||
|
|
||||||
else:
|
|
||||||
logging.warning(f'{req.params["attr"]} not in {CONFIG_ATTRS}')
|
|
||||||
|
|
||||||
case 'stats':
|
|
||||||
logging.info('stats')
|
|
||||||
generated, joined, role = await get_user_stats(conn, user)
|
|
||||||
|
|
||||||
result = {
|
|
||||||
'generated': generated,
|
|
||||||
'joined': joined.strftime(DATE_FORMAT),
|
|
||||||
'role': role
|
|
||||||
}
|
|
||||||
|
|
||||||
case _:
|
|
||||||
logging.warn('unknown method')
|
|
||||||
|
|
||||||
except SkynetDGPUOffline as e:
|
|
||||||
result = {
|
|
||||||
'error': 'skynet_dgpu_offline',
|
|
||||||
'message': str(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
except SkynetDGPUOverloaded as e:
|
|
||||||
result = {
|
|
||||||
'error': 'skynet_dgpu_overloaded',
|
|
||||||
'message': str(e),
|
|
||||||
'nodes': len(nodes)
|
|
||||||
}
|
|
||||||
|
|
||||||
except SkynetDGPUComputeError as e:
|
|
||||||
result = {
|
|
||||||
'error': 'skynet_dgpu_compute_error',
|
|
||||||
'message': str(e)
|
|
||||||
}
|
|
||||||
except BaseException as e:
|
|
||||||
traceback.print_exception(type(e), e, e.__traceback__)
|
|
||||||
result = {
|
|
||||||
'error': 'skynet_internal_error',
|
|
||||||
'message': str(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp = SkynetRPCResponse()
|
resp = SkynetRPCResponse()
|
||||||
resp.result.update(result)
|
|
||||||
|
|
||||||
if security:
|
|
||||||
resp.auth.cert = 'skynet'
|
|
||||||
resp.auth.sig = sign_protobuf_msg(resp, tls_key)
|
|
||||||
|
|
||||||
logging.info('sending response')
|
|
||||||
await rpc_ctx.asend(resp.SerializeToString())
|
|
||||||
rpc_ctx.close()
|
|
||||||
logging.info('done')
|
|
||||||
|
|
||||||
async def request_service(n):
|
|
||||||
nonlocal next_worker
|
|
||||||
while True:
|
|
||||||
ctx = sock.new_context()
|
|
||||||
req = SkynetRPCRequest()
|
|
||||||
req.ParseFromString(await ctx.arecv())
|
|
||||||
|
|
||||||
if security:
|
|
||||||
if req.auth.cert not in tls_whitelist:
|
|
||||||
logging.warning(
|
|
||||||
f'{req.cert} not in tls whitelist and security=True')
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
verify_protobuf_msg(req, tls_whitelist[req.auth.cert])
|
|
||||||
|
|
||||||
except ValueError:
|
|
||||||
logging.warning(
|
|
||||||
f'{req.cert} sent an unauthenticated msg with security=True')
|
|
||||||
continue
|
|
||||||
|
|
||||||
result = {}
|
|
||||||
|
|
||||||
|
try:
|
||||||
match req.method:
|
match req.method:
|
||||||
case 'skynet_shutdown':
|
|
||||||
raise SkynetShutdownRequested
|
|
||||||
|
|
||||||
case 'dgpu_online':
|
case 'dgpu_online':
|
||||||
connect_node(req.uid)
|
connect_node(req)
|
||||||
|
|
||||||
|
case 'dgpu_call':
|
||||||
|
nid = get_next_worker()
|
||||||
|
idx = list(nodes.keys()).index(nid)
|
||||||
|
node = nodes[nid]
|
||||||
|
logging.info(f'dgpu_call {idx}/{len(nodes)} {nid} @ {node["dgpu_addr"]}')
|
||||||
|
dgpu_time = await node['session'].rpc('dgpu_time')
|
||||||
|
if 'ok' not in dgpu_time.result:
|
||||||
|
status = MessageToDict(dgpu_time.result)
|
||||||
|
logging.warning(json.dumps(status, indent=4))
|
||||||
|
disconnect_node(nid)
|
||||||
|
raise SkynetDGPUComputeError(status['error'])
|
||||||
|
|
||||||
|
dgpu_time = dgpu_time.result['ok']
|
||||||
|
logging.info(f'ping to {nid}: {time_ms() - dgpu_time} ms')
|
||||||
|
|
||||||
|
try:
|
||||||
|
dgpu_result = await node['session'].rpc(
|
||||||
|
timeout=45, # give this 45 sec to run cause its compute
|
||||||
|
binext=req.bin,
|
||||||
|
**req.params
|
||||||
|
)
|
||||||
|
result = MessageToDict(dgpu_result.result)
|
||||||
|
|
||||||
|
if dgpu_result.bin:
|
||||||
|
resp.bin = dgpu_result.bin
|
||||||
|
|
||||||
|
except trio.TooSlowError:
|
||||||
|
result = {'error': 'timeout while processing request'}
|
||||||
|
|
||||||
case 'dgpu_offline':
|
case 'dgpu_offline':
|
||||||
disconnect_node(req.uid)
|
disconnect_node(req.uid)
|
||||||
|
|
||||||
case 'dgpu_workers':
|
case 'dgpu_workers':
|
||||||
result = len(nodes)
|
result = {'ok': len(nodes)}
|
||||||
|
|
||||||
case 'dgpu_next':
|
case 'dgpu_next':
|
||||||
result = next_worker
|
result = {'ok': next_worker}
|
||||||
|
|
||||||
case 'heartbeat':
|
case 'skynet_shutdown':
|
||||||
logging.info('beat')
|
raise SkynetShutdownRequested
|
||||||
result = {'time': time.time()}
|
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
n.start_soon(
|
logging.warning(f'Unknown method {req.method}')
|
||||||
handle_user_request, ctx, req)
|
result = {'error': 'unknown method'}
|
||||||
continue
|
|
||||||
|
|
||||||
resp = SkynetRPCResponse()
|
except BaseException as e:
|
||||||
resp.result.update({'ok': result})
|
result = {'error': str(e)}
|
||||||
|
|
||||||
if security:
|
resp.result.update(result)
|
||||||
resp.auth.cert = 'skynet'
|
|
||||||
resp.auth.sig = sign_protobuf_msg(resp, tls_key)
|
|
||||||
|
|
||||||
await ctx.asend(resp.SerializeToString())
|
return resp
|
||||||
|
|
||||||
ctx.close()
|
rpc_server = SessionServer(
|
||||||
|
rpc_address,
|
||||||
|
rpc_handler,
|
||||||
|
cert_name='brain.cert',
|
||||||
|
key_name='brain.key'
|
||||||
|
)
|
||||||
|
|
||||||
|
async with rpc_server.open():
|
||||||
async with trio.open_nursery() as n:
|
logging.info('rpc server is up')
|
||||||
n.start_soon(dgpu_bus_streamer)
|
|
||||||
n.start_soon(dgpu_heartbeat_service)
|
|
||||||
n.start_soon(request_service, n)
|
|
||||||
logging.info('starting rpc service')
|
|
||||||
yield
|
yield
|
||||||
logging.info('stopping rpc service')
|
logging.info('skynet is shuting down...')
|
||||||
n.cancel_scope.cancel()
|
|
||||||
|
|
||||||
|
logging.info('skynet down.')
|
||||||
@acm
|
|
||||||
async def run_skynet(
|
|
||||||
db_user: str = DB_USER,
|
|
||||||
db_pass: str = DB_PASS,
|
|
||||||
db_host: str = DB_HOST,
|
|
||||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
|
||||||
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
|
||||||
security: bool = True
|
|
||||||
):
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logging.info('skynet is starting')
|
|
||||||
|
|
||||||
tls_config = None
|
|
||||||
if security:
|
|
||||||
# load tls certs
|
|
||||||
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
|
|
||||||
|
|
||||||
tls_key_data = (certs_dir / DEFAULT_CERT_SKYNET_PRIV).read_text()
|
|
||||||
tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
|
|
||||||
|
|
||||||
tls_cert_data = (certs_dir / DEFAULT_CERT_SKYNET_PUB).read_text()
|
|
||||||
tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data)
|
|
||||||
|
|
||||||
tls_whitelist = {}
|
|
||||||
for cert_path in (certs_dir / 'whitelist').glob('*.cert'):
|
|
||||||
tls_whitelist[cert_path.stem] = load_certificate(
|
|
||||||
FILETYPE_PEM, cert_path.read_text())
|
|
||||||
|
|
||||||
cert_start = tls_cert_data.index('\n') + 1
|
|
||||||
logging.info(f'tls_cert: {tls_cert_data[cert_start:cert_start+64]}...')
|
|
||||||
logging.info(f'tls_whitelist len: {len(tls_whitelist)}')
|
|
||||||
|
|
||||||
rpc_address = 'tls+' + rpc_address
|
|
||||||
dgpu_address = 'tls+' + dgpu_address
|
|
||||||
tls_config = TLSConfig(
|
|
||||||
TLSConfig.MODE_SERVER,
|
|
||||||
own_key_string=tls_key_data,
|
|
||||||
own_cert_string=tls_cert_data)
|
|
||||||
|
|
||||||
with (
|
|
||||||
pynng.Rep0(recv_max_size=0) as rpc_sock,
|
|
||||||
pynng.Bus0(recv_max_size=0) as dgpu_bus
|
|
||||||
):
|
|
||||||
async with open_database_connection(
|
|
||||||
db_user, db_pass, db_host) as db_pool:
|
|
||||||
|
|
||||||
logging.info('connected to db.')
|
|
||||||
if security:
|
|
||||||
rpc_sock.tls_config = tls_config
|
|
||||||
dgpu_bus.tls_config = tls_config
|
|
||||||
|
|
||||||
rpc_sock.listen(rpc_address)
|
|
||||||
dgpu_bus.listen(dgpu_address)
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with open_rpc_service(
|
|
||||||
rpc_sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|
||||||
yield
|
|
||||||
|
|
||||||
except SkynetShutdownRequested:
|
|
||||||
...
|
|
||||||
|
|
||||||
logging.info('disconnected from db.')
|
|
||||||
|
|
|
@ -17,8 +17,8 @@ if torch_enabled:
|
||||||
from .dgpu import open_dgpu_node
|
from .dgpu import open_dgpu_node
|
||||||
|
|
||||||
from .brain import run_skynet
|
from .brain import run_skynet
|
||||||
|
from .config import *
|
||||||
from .constants import ALGOS, DEFAULT_RPC_ADDR, DEFAULT_DGPU_ADDR
|
from .constants import ALGOS, DEFAULT_RPC_ADDR, DEFAULT_DGPU_ADDR
|
||||||
|
|
||||||
from .frontend.telegram import run_skynet_telegram
|
from .frontend.telegram import run_skynet_telegram
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,8 +38,8 @@ def skynet(*args, **kwargs):
|
||||||
@click.option('--steps', '-s', default=26)
|
@click.option('--steps', '-s', default=26)
|
||||||
@click.option('--seed', '-S', default=None)
|
@click.option('--seed', '-S', default=None)
|
||||||
def txt2img(*args, **kwargs):
|
def txt2img(*args, **kwargs):
|
||||||
assert 'HF_TOKEN' in os.environ
|
_, hf_token, _, cfg = init_env_from_config()
|
||||||
utils.txt2img(os.environ['HF_TOKEN'], **kwargs)
|
utils.txt2img(hf_token, **kwargs)
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.option('--model', '-m', default='midj')
|
@click.option('--model', '-m', default='midj')
|
||||||
|
@ -52,9 +52,9 @@ def txt2img(*args, **kwargs):
|
||||||
@click.option('--steps', '-s', default=26)
|
@click.option('--steps', '-s', default=26)
|
||||||
@click.option('--seed', '-S', default=None)
|
@click.option('--seed', '-S', default=None)
|
||||||
def img2img(model, prompt, input, output, strength, guidance, steps, seed):
|
def img2img(model, prompt, input, output, strength, guidance, steps, seed):
|
||||||
assert 'HF_TOKEN' in os.environ
|
_, hf_token, _, cfg = init_env_from_config()
|
||||||
utils.img2img(
|
utils.img2img(
|
||||||
os.environ['HF_TOKEN'],
|
hf_token,
|
||||||
model=model,
|
model=model,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
img_path=input,
|
img_path=input,
|
||||||
|
@ -76,6 +76,12 @@ def upscale(input, output, model):
|
||||||
model_path=model)
|
model_path=model)
|
||||||
|
|
||||||
|
|
||||||
|
@skynet.command()
|
||||||
|
def download():
|
||||||
|
_, hf_token, _, cfg = init_env_from_config()
|
||||||
|
utils.download_all_models(hf_token)
|
||||||
|
|
||||||
|
|
||||||
@skynet.group()
|
@skynet.group()
|
||||||
def run(*args, **kwargs):
|
def run(*args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
@ -85,29 +91,17 @@ def run(*args, **kwargs):
|
||||||
@click.option('--loglevel', '-l', default='warning', help='Logging level')
|
@click.option('--loglevel', '-l', default='warning', help='Logging level')
|
||||||
@click.option(
|
@click.option(
|
||||||
'--host', '-H', default=DEFAULT_RPC_ADDR)
|
'--host', '-H', default=DEFAULT_RPC_ADDR)
|
||||||
@click.option(
|
|
||||||
'--host-dgpu', '-D', default=DEFAULT_DGPU_ADDR)
|
|
||||||
@click.option(
|
|
||||||
'--db-host', '-h', default='localhost:5432')
|
|
||||||
@click.option(
|
|
||||||
'--db-pass', '-p', default='password')
|
|
||||||
def brain(
|
def brain(
|
||||||
loglevel: str,
|
loglevel: str,
|
||||||
host: str,
|
host: str
|
||||||
host_dgpu: str,
|
|
||||||
db_host: str,
|
|
||||||
db_pass: str
|
|
||||||
):
|
):
|
||||||
async def _run_skynet():
|
async def _run_skynet():
|
||||||
async with run_skynet(
|
async with run_skynet(
|
||||||
db_host=db_host,
|
rpc_address=host
|
||||||
db_pass=db_pass,
|
|
||||||
rpc_address=host,
|
|
||||||
dgpu_address=host_dgpu
|
|
||||||
):
|
):
|
||||||
await trio.sleep_forever()
|
await trio.sleep_forever()
|
||||||
|
|
||||||
trio_asyncio.run(_run_skynet)
|
trio.run(_run_skynet)
|
||||||
|
|
||||||
|
|
||||||
@run.command()
|
@run.command()
|
||||||
|
@ -115,9 +109,9 @@ def brain(
|
||||||
@click.option(
|
@click.option(
|
||||||
'--uid', '-u', required=True)
|
'--uid', '-u', required=True)
|
||||||
@click.option(
|
@click.option(
|
||||||
'--key', '-k', default='dgpu')
|
'--key', '-k', default='dgpu.key')
|
||||||
@click.option(
|
@click.option(
|
||||||
'--cert', '-c', default='whitelist/dgpu')
|
'--cert', '-c', default='whitelist/dgpu.cert')
|
||||||
@click.option(
|
@click.option(
|
||||||
'--algos', '-a', default=json.dumps(['midj']))
|
'--algos', '-a', default=json.dumps(['midj']))
|
||||||
@click.option(
|
@click.option(
|
||||||
|
@ -159,11 +153,11 @@ def telegram(
|
||||||
cert: str,
|
cert: str,
|
||||||
rpc: str
|
rpc: str
|
||||||
):
|
):
|
||||||
assert 'TG_TOKEN' in os.environ
|
_, _, tg_token, cfg = init_env_from_config()
|
||||||
trio_asyncio.run(
|
trio_asyncio.run(
|
||||||
partial(
|
partial(
|
||||||
run_skynet_telegram,
|
run_skynet_telegram,
|
||||||
os.environ['TG_TOKEN'],
|
tg_token,
|
||||||
key_name=key,
|
key_name=key,
|
||||||
cert_name=cert,
|
cert_name=cert,
|
||||||
rpc_address=rpc
|
rpc_address=rpc
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from configparser import ConfigParser
|
||||||
|
|
||||||
|
from .constants import DEFAULT_CONFIG_PATH
|
||||||
|
|
||||||
|
|
||||||
|
def load_skynet_ini(
|
||||||
|
file_path=DEFAULT_CONFIG_PATH
|
||||||
|
):
|
||||||
|
config = ConfigParser()
|
||||||
|
config.read(file_path)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def init_env_from_config(
|
||||||
|
file_path=DEFAULT_CONFIG_PATH
|
||||||
|
):
|
||||||
|
config = load_skynet_ini()
|
||||||
|
|
||||||
|
if 'HF_TOKEN' in os.environ:
|
||||||
|
hf_token = os.environ['HF_TOKEN']
|
||||||
|
else:
|
||||||
|
hf_token = config['skynet.dgpu']['hf_token']
|
||||||
|
|
||||||
|
if 'HF_HOME' in os.environ:
|
||||||
|
hf_home = os.environ['HF_HOME']
|
||||||
|
else:
|
||||||
|
hf_home = config['skynet.dgpu']['hf_home']
|
||||||
|
|
||||||
|
if 'TG_TOKEN' in os.environ:
|
||||||
|
tg_token = os.environ['TG_TOKEN']
|
||||||
|
else:
|
||||||
|
tg_token = config['skynet.telegram']['token']
|
||||||
|
|
||||||
|
return hf_home, hf_token, tg_token, config
|
|
@ -1,14 +1,9 @@
|
||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
|
|
||||||
VERSION = '0.1a8'
|
VERSION = '0.1a9'
|
||||||
|
|
||||||
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
|
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
|
||||||
|
|
||||||
DB_HOST = 'localhost:5432'
|
|
||||||
DB_USER = 'skynet'
|
|
||||||
DB_PASS = 'password'
|
|
||||||
DB_NAME = 'skynet'
|
|
||||||
|
|
||||||
ALGOS = {
|
ALGOS = {
|
||||||
'midj': 'prompthero/openjourney',
|
'midj': 'prompthero/openjourney',
|
||||||
'stable': 'runwayml/stable-diffusion-v1-5',
|
'stable': 'runwayml/stable-diffusion-v1-5',
|
||||||
|
@ -118,6 +113,7 @@ DEFAULT_ALGO = 'midj'
|
||||||
DEFAULT_ROLE = 'pleb'
|
DEFAULT_ROLE = 'pleb'
|
||||||
DEFAULT_UPSCALER = None
|
DEFAULT_UPSCALER = None
|
||||||
|
|
||||||
|
DEFAULT_CONFIG_PATH = 'skynet.ini'
|
||||||
DEFAULT_CERTS_DIR = 'certs'
|
DEFAULT_CERTS_DIR = 'certs'
|
||||||
DEFAULT_CERT_WHITELIST_DIR = 'whitelist'
|
DEFAULT_CERT_WHITELIST_DIR = 'whitelist'
|
||||||
DEFAULT_CERT_SKYNET_PUB = 'brain.cert'
|
DEFAULT_CERT_SKYNET_PUB = 'brain.cert'
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
from .proxy import open_database_connection
|
||||||
|
|
||||||
|
from .functions import open_new_database
|
|
@ -1,18 +1,21 @@
|
||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
import string
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from contextlib import asynccontextmanager as acm
|
from contextlib import contextmanager as cm
|
||||||
|
|
||||||
import trio
|
import docker
|
||||||
import triopg
|
import psycopg2
|
||||||
import trio_asyncio
|
|
||||||
|
|
||||||
from asyncpg.exceptions import UndefinedColumnError
|
from asyncpg.exceptions import UndefinedColumnError
|
||||||
|
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
|
||||||
|
|
||||||
from .constants import *
|
from ..constants import *
|
||||||
|
|
||||||
|
|
||||||
DB_INIT_SQL = '''
|
DB_INIT_SQL = '''
|
||||||
|
@ -75,29 +78,67 @@ def try_decode_uid(uid: str):
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
@acm
|
@cm
|
||||||
async def open_database_connection(
|
def open_new_database():
|
||||||
db_user: str = DB_USER,
|
rpassword = ''.join(
|
||||||
db_pass: str = DB_PASS,
|
random.choice(string.ascii_lowercase)
|
||||||
db_host: str = DB_HOST,
|
for i in range(12))
|
||||||
db_name: str = DB_NAME
|
password = ''.join(
|
||||||
):
|
random.choice(string.ascii_lowercase)
|
||||||
async with trio_asyncio.open_loop() as loop:
|
for i in range(12))
|
||||||
async with triopg.create_pool(
|
|
||||||
dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}'
|
|
||||||
) as pool_conn:
|
|
||||||
async with pool_conn.acquire() as conn:
|
|
||||||
res = await conn.execute(f'''
|
|
||||||
select distinct table_schema
|
|
||||||
from information_schema.tables
|
|
||||||
where table_schema = \'{db_name}\'
|
|
||||||
''')
|
|
||||||
if '1' in res:
|
|
||||||
logging.info('schema already in db, skipping init')
|
|
||||||
else:
|
|
||||||
await conn.execute(DB_INIT_SQL)
|
|
||||||
|
|
||||||
yield pool_conn
|
dclient = docker.from_env()
|
||||||
|
|
||||||
|
container = dclient.containers.run(
|
||||||
|
'postgres',
|
||||||
|
name='skynet-test-postgres',
|
||||||
|
ports={'5432/tcp': None},
|
||||||
|
environment={
|
||||||
|
'POSTGRES_PASSWORD': rpassword
|
||||||
|
},
|
||||||
|
detach=True,
|
||||||
|
remove=True
|
||||||
|
)
|
||||||
|
|
||||||
|
for log in container.logs(stream=True):
|
||||||
|
log = log.decode().rstrip()
|
||||||
|
logging.info(log)
|
||||||
|
if ('database system is ready to accept connections' in log or
|
||||||
|
'database system is shut down' in log):
|
||||||
|
break
|
||||||
|
|
||||||
|
# ip = container.attrs['NetworkSettings']['IPAddress']
|
||||||
|
container.reload()
|
||||||
|
port = container.ports['5432/tcp'][0]['HostPort']
|
||||||
|
host = f'localhost:{port}'
|
||||||
|
|
||||||
|
# why print the system is ready to accept connections when its not
|
||||||
|
# postgres? wtf
|
||||||
|
time.sleep(1)
|
||||||
|
logging.info('creating skynet db...')
|
||||||
|
|
||||||
|
conn = psycopg2.connect(
|
||||||
|
user='postgres',
|
||||||
|
password=rpassword,
|
||||||
|
host='localhost',
|
||||||
|
port=port
|
||||||
|
)
|
||||||
|
logging.info('connected...')
|
||||||
|
conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
|
||||||
|
with conn.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
f'CREATE USER skynet WITH PASSWORD \'{password}\'')
|
||||||
|
cursor.execute(
|
||||||
|
f'CREATE DATABASE skynet')
|
||||||
|
cursor.execute(
|
||||||
|
f'GRANT ALL PRIVILEGES ON DATABASE skynet TO skynet')
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
logging.info('done.')
|
||||||
|
yield container, password, host
|
||||||
|
|
||||||
|
container.stop()
|
||||||
|
|
||||||
|
|
||||||
async def get_user(conn, uid: str):
|
async def get_user(conn, uid: str):
|
|
@ -0,0 +1,123 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager as acm
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import tractor
|
||||||
|
import asyncpg
|
||||||
|
import asyncio
|
||||||
|
import trio_asyncio
|
||||||
|
|
||||||
|
|
||||||
|
_spawn_kwargs = {
|
||||||
|
'infect_asyncio': True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def aio_db_proxy(
|
||||||
|
to_trio: trio.MemorySendChannel,
|
||||||
|
from_trio: asyncio.Queue,
|
||||||
|
db_user: str = 'skynet',
|
||||||
|
db_pass: str = 'password',
|
||||||
|
db_host: str = 'localhost:5432',
|
||||||
|
db_name: str = 'skynet'
|
||||||
|
) -> None:
|
||||||
|
db = importlib.import_module('skynet.db.functions')
|
||||||
|
|
||||||
|
pool = await asyncpg.create_pool(
|
||||||
|
dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}')
|
||||||
|
|
||||||
|
async with pool_conn.acquire() as conn:
|
||||||
|
res = await conn.execute(f'''
|
||||||
|
select distinct table_schema
|
||||||
|
from information_schema.tables
|
||||||
|
where table_schema = \'{db_name}\'
|
||||||
|
''')
|
||||||
|
if '1' in res:
|
||||||
|
logging.info('schema already in db, skipping init')
|
||||||
|
else:
|
||||||
|
await conn.execute(DB_INIT_SQL)
|
||||||
|
|
||||||
|
# a first message must be sent **from** this ``asyncio``
|
||||||
|
# task or the ``trio`` side will never unblock from
|
||||||
|
# ``tractor.to_asyncio.open_channel_from():``
|
||||||
|
to_trio.send_nowait('start')
|
||||||
|
|
||||||
|
# XXX: this uses an ``from_trio: asyncio.Queue`` currently but we
|
||||||
|
# should probably offer something better.
|
||||||
|
while True:
|
||||||
|
msg = await from_trio.get()
|
||||||
|
|
||||||
|
method = getattr(db, msg.get('method'))
|
||||||
|
args = getattr(db, msg.get('args', []))
|
||||||
|
kwargs = getattr(db, msg.get('kwargs', {}))
|
||||||
|
|
||||||
|
async with pool_conn.acquire() as conn:
|
||||||
|
result = await method(conn, *args, **kwargs)
|
||||||
|
to_trio.send_nowait(result)
|
||||||
|
|
||||||
|
|
||||||
|
@tractor.context
|
||||||
|
async def trio_to_aio_db_proxy(
|
||||||
|
ctx: tractor.Context,
|
||||||
|
db_user: str = 'skynet',
|
||||||
|
db_pass: str = 'password',
|
||||||
|
db_host: str = 'localhost:5432',
|
||||||
|
db_name: str = 'skynet'
|
||||||
|
):
|
||||||
|
# this will block until the ``asyncio`` task sends a "first"
|
||||||
|
# message.
|
||||||
|
async with tractor.to_asyncio.open_channel_from(
|
||||||
|
aio_db_proxy,
|
||||||
|
db_user=db_user,
|
||||||
|
db_pass=db_pass,
|
||||||
|
db_host=db_host,
|
||||||
|
db_name=db_name
|
||||||
|
) as (first, chan):
|
||||||
|
|
||||||
|
assert first == 'start'
|
||||||
|
await ctx.started(first)
|
||||||
|
|
||||||
|
async with ctx.open_stream() as stream:
|
||||||
|
|
||||||
|
async for msg in stream:
|
||||||
|
await chan.send(msg)
|
||||||
|
|
||||||
|
out = await chan.receive()
|
||||||
|
# echo back to parent actor-task
|
||||||
|
await stream.send(out)
|
||||||
|
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def open_database_connection(
|
||||||
|
db_user: str = 'skynet',
|
||||||
|
db_pass: str = 'password',
|
||||||
|
db_host: str = 'localhost:5432',
|
||||||
|
db_name: str = 'skynet'
|
||||||
|
):
|
||||||
|
async with tractor.open_nursery() as n:
|
||||||
|
p = await n.start_actor(
|
||||||
|
'aio_db_proxy',
|
||||||
|
enable_modules=[__name__],
|
||||||
|
infect_asyncio=True,
|
||||||
|
)
|
||||||
|
async with p.open_context(
|
||||||
|
trio_to_aio_db_proxy,
|
||||||
|
db_user=db_user,
|
||||||
|
db_pass=db_pass,
|
||||||
|
db_host=db_host,
|
||||||
|
db_name=db_name
|
||||||
|
) as (ctx, first):
|
||||||
|
async with ctx.open_stream() as stream:
|
||||||
|
|
||||||
|
async def _db_pc(method: str, *args, **kwargs):
|
||||||
|
await stream.send({
|
||||||
|
'method': method,
|
||||||
|
'args': args,
|
||||||
|
'kwargs': kwargs
|
||||||
|
})
|
||||||
|
return await stream.receive()
|
||||||
|
|
||||||
|
yield _db_pc
|
405
skynet/dgpu.py
405
skynet/dgpu.py
|
@ -2,29 +2,17 @@
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import io
|
import io
|
||||||
import trio
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
|
||||||
import time
|
|
||||||
import zlib
|
|
||||||
import random
|
import random
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from pathlib import Path
|
|
||||||
from contextlib import ExitStack
|
|
||||||
|
|
||||||
import pynng
|
import trio
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pynng import TLSConfig
|
from pynng import Context
|
||||||
from OpenSSL.crypto import (
|
|
||||||
load_privatekey,
|
|
||||||
load_certificate,
|
|
||||||
FILETYPE_PEM
|
|
||||||
)
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
StableDiffusionImg2ImgPipeline,
|
StableDiffusionImg2ImgPipeline,
|
||||||
|
@ -34,12 +22,9 @@ from realesrgan import RealESRGANer
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from diffusers.models import UNet2DConditionModel
|
from diffusers.models import UNet2DConditionModel
|
||||||
|
|
||||||
from .utils import (
|
from .utils import *
|
||||||
pipeline_for,
|
from .network import *
|
||||||
convert_from_cv2_to_image, convert_from_image_to_cv2
|
|
||||||
)
|
|
||||||
from .protobuf import *
|
from .protobuf import *
|
||||||
from .frontend import open_skynet_rpc
|
|
||||||
from .constants import *
|
from .constants import *
|
||||||
|
|
||||||
|
|
||||||
|
@ -64,65 +49,16 @@ class DGPUComputeError(BaseException):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class ReconnectingBus:
|
|
||||||
|
|
||||||
def __init__(self, address: str, tls_config: Optional[TLSConfig]):
|
|
||||||
self.address = address
|
|
||||||
self.tls_config = tls_config
|
|
||||||
|
|
||||||
self._stack = ExitStack()
|
|
||||||
self._sock = None
|
|
||||||
self._closed = True
|
|
||||||
|
|
||||||
def connect(self):
|
|
||||||
self._sock = self._stack.enter_context(
|
|
||||||
pynng.Bus0(recv_max_size=0))
|
|
||||||
self._sock.tls_config = self.tls_config
|
|
||||||
self._sock.dial(self.address)
|
|
||||||
self._closed = False
|
|
||||||
|
|
||||||
async def arecv(self):
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
return await self._sock.arecv()
|
|
||||||
|
|
||||||
except pynng.exceptions.Closed:
|
|
||||||
if self._closed:
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def asend(self, msg):
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
return await self._sock.asend(msg)
|
|
||||||
|
|
||||||
except pynng.exceptions.Closed:
|
|
||||||
if self._closed:
|
|
||||||
raise
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self._stack.close()
|
|
||||||
self._stack = ExitStack()
|
|
||||||
self._closed = True
|
|
||||||
|
|
||||||
def reconnect(self):
|
|
||||||
self.close()
|
|
||||||
self.connect()
|
|
||||||
|
|
||||||
|
|
||||||
async def open_dgpu_node(
|
async def open_dgpu_node(
|
||||||
cert_name: str,
|
cert_name: str,
|
||||||
unique_id: str,
|
unique_id: str,
|
||||||
key_name: Optional[str],
|
key_name: Optional[str],
|
||||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||||
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
||||||
initial_algos: Optional[List[str]] = None,
|
initial_algos: Optional[List[str]] = None
|
||||||
security: bool = True
|
|
||||||
):
|
):
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
logging.info(f'starting dgpu node!')
|
logging.info(f'starting dgpu node!')
|
||||||
|
|
||||||
name = uuid.uuid4()
|
|
||||||
|
|
||||||
logging.info(f'loading models...')
|
logging.info(f'loading models...')
|
||||||
|
|
||||||
upscaler = init_upscaler()
|
upscaler = init_upscaler()
|
||||||
|
@ -141,241 +77,140 @@ async def open_dgpu_node(
|
||||||
logging.info('memory summary:')
|
logging.info('memory summary:')
|
||||||
logging.info('\n' + torch.cuda.memory_summary())
|
logging.info('\n' + torch.cuda.memory_summary())
|
||||||
|
|
||||||
async def gpu_compute_one(ireq: DiffusionParameters, image=None):
|
async def gpu_compute_one(method: str, params: dict, binext: Optional[bytes] = None):
|
||||||
algo = ireq.algo + 'img' if image else ireq.algo
|
match method:
|
||||||
if algo not in models:
|
case 'diffuse':
|
||||||
least_used = list(models.keys())[0]
|
image = None
|
||||||
for model in models:
|
algo = params['algo']
|
||||||
if models[least_used]['generated'] > models[model]['generated']:
|
if binext:
|
||||||
least_used = model
|
algo += 'img'
|
||||||
|
image = Image.open(io.BytesIO(binext))
|
||||||
|
w, h = image.size
|
||||||
|
logging.info(f'user sent img of size {image.size}')
|
||||||
|
|
||||||
del models[least_used]
|
if w > 512 or h > 512:
|
||||||
gc.collect()
|
image.thumbnail((512, 512))
|
||||||
|
logging.info(f'resized it to {image.size}')
|
||||||
|
|
||||||
models[algo] = {
|
if algo not in models:
|
||||||
'pipe': pipeline_for(ireq.algo, image=True if image else False),
|
logging.info(f'{algo} not in loaded models, swapping...')
|
||||||
'generated': 0
|
least_used = list(models.keys())[0]
|
||||||
}
|
for model in models:
|
||||||
|
if models[least_used]['generated'] > models[model]['generated']:
|
||||||
|
least_used = model
|
||||||
|
|
||||||
_params = {}
|
del models[least_used]
|
||||||
if ireq.image:
|
gc.collect()
|
||||||
_params['image'] = image
|
|
||||||
_params['strength'] = ireq.strength
|
|
||||||
|
|
||||||
else:
|
models[algo] = {
|
||||||
_params['width'] = int(ireq.width)
|
'pipe': pipeline_for(params['algo'], image=True if binext else False),
|
||||||
_params['height'] = int(ireq.height)
|
'generated': 0
|
||||||
|
}
|
||||||
|
logging.info(f'swapping done.')
|
||||||
|
|
||||||
try:
|
_params = {}
|
||||||
image = models[algo]['pipe'](
|
logging.info(method)
|
||||||
ireq.prompt,
|
logging.info(json.dumps(params, indent=4))
|
||||||
**_params,
|
logging.info(f'binext: {len(binext) if binext else 0} bytes')
|
||||||
guidance_scale=ireq.guidance,
|
if binext:
|
||||||
num_inference_steps=int(ireq.step),
|
_params['image'] = image
|
||||||
generator=torch.Generator("cuda").manual_seed(ireq.seed)
|
_params['strength'] = params['strength']
|
||||||
).images[0]
|
|
||||||
|
|
||||||
if ireq.upscaler == 'x4':
|
else:
|
||||||
logging.info(f'size: {len(image.tobytes())}')
|
_params['width'] = int(params['width'])
|
||||||
logging.info('performing upscale...')
|
_params['height'] = int(params['height'])
|
||||||
input_img = image.convert('RGB')
|
|
||||||
up_img, _ = upscaler.enhance(
|
|
||||||
convert_from_image_to_cv2(input_img), outscale=4)
|
|
||||||
|
|
||||||
image = convert_from_cv2_to_image(up_img)
|
try:
|
||||||
logging.info('done')
|
image = models[algo]['pipe'](
|
||||||
|
params['prompt'],
|
||||||
|
**_params,
|
||||||
|
guidance_scale=params['guidance'],
|
||||||
|
num_inference_steps=int(params['step']),
|
||||||
|
generator=torch.Generator("cuda").manual_seed(
|
||||||
|
int(params['seed']) if params['seed'] else random.randint(0, 2 ** 64)
|
||||||
|
)
|
||||||
|
).images[0]
|
||||||
|
|
||||||
img_byte_arr = io.BytesIO()
|
if params['upscaler'] == 'x4':
|
||||||
image.save(img_byte_arr, format='PNG')
|
logging.info(f'size: {len(image.tobytes())}')
|
||||||
raw_img = img_byte_arr.getvalue()
|
logging.info('performing upscale...')
|
||||||
logging.info(f'final img size {len(raw_img)} bytes.')
|
input_img = image.convert('RGB')
|
||||||
|
up_img, _ = upscaler.enhance(
|
||||||
|
convert_from_image_to_cv2(input_img), outscale=4)
|
||||||
|
|
||||||
return raw_img
|
image = convert_from_cv2_to_image(up_img)
|
||||||
|
logging.info('done')
|
||||||
|
|
||||||
except BaseException as e:
|
img_byte_arr = io.BytesIO()
|
||||||
logging.error(e)
|
image.save(img_byte_arr, format='PNG')
|
||||||
raise DGPUComputeError(str(e))
|
raw_img = img_byte_arr.getvalue()
|
||||||
|
logging.info(f'final img size {len(raw_img)} bytes.')
|
||||||
|
|
||||||
finally:
|
return raw_img
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
except BaseException as e:
|
||||||
|
logging.error(e)
|
||||||
|
raise DGPUComputeError(str(e))
|
||||||
|
|
||||||
|
finally:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise DGPUComputeError('Unsupported compute method')
|
||||||
|
|
||||||
|
async def rpc_handler(req: SkynetRPCRequest, ctx: Context):
|
||||||
|
result = {}
|
||||||
|
resp = SkynetRPCResponse()
|
||||||
|
|
||||||
|
match req.method:
|
||||||
|
case 'dgpu_time':
|
||||||
|
result = {'ok': time_ms()}
|
||||||
|
|
||||||
|
case _:
|
||||||
|
logging.debug(f'dgpu got one request: {req.method}')
|
||||||
|
try:
|
||||||
|
resp.bin = await gpu_compute_one(
|
||||||
|
req.method, MessageToDict(req.params),
|
||||||
|
binext=req.bin if req.bin else None
|
||||||
|
)
|
||||||
|
logging.debug(f'dgpu processed one request')
|
||||||
|
|
||||||
|
except DGPUComputeError as e:
|
||||||
|
result = {'error': str(e)}
|
||||||
|
|
||||||
|
resp.result.update(result)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
rpc_server = SessionServer(
|
||||||
|
dgpu_address,
|
||||||
|
rpc_handler,
|
||||||
|
cert_name=cert_name,
|
||||||
|
key_name=key_name
|
||||||
|
)
|
||||||
|
skynet_rpc = SessionClient(
|
||||||
|
rpc_address,
|
||||||
|
unique_id,
|
||||||
|
cert_name=cert_name,
|
||||||
|
key_name=key_name
|
||||||
|
)
|
||||||
|
skynet_rpc.connect()
|
||||||
|
|
||||||
|
|
||||||
async with (
|
async with rpc_server.open() as rpc_server:
|
||||||
open_skynet_rpc(
|
res = await skynet_rpc.rpc(
|
||||||
unique_id,
|
'dgpu_online', {
|
||||||
rpc_address=rpc_address,
|
'dgpu_addr': rpc_server.addr,
|
||||||
security=security,
|
'cert': cert_name
|
||||||
cert_name=cert_name,
|
})
|
||||||
key_name=key_name
|
|
||||||
) as rpc_call,
|
|
||||||
trio.open_nursery() as n
|
|
||||||
):
|
|
||||||
|
|
||||||
tls_config = None
|
|
||||||
if security:
|
|
||||||
# load tls certs
|
|
||||||
if not key_name:
|
|
||||||
key_name = cert_name
|
|
||||||
|
|
||||||
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
|
|
||||||
|
|
||||||
skynet_cert_path = certs_dir / 'brain.cert'
|
|
||||||
tls_cert_path = certs_dir / f'{cert_name}.cert'
|
|
||||||
tls_key_path = certs_dir / f'{key_name}.key'
|
|
||||||
|
|
||||||
cert_name = tls_cert_path.stem
|
|
||||||
|
|
||||||
skynet_cert_data = skynet_cert_path.read_text()
|
|
||||||
skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data)
|
|
||||||
|
|
||||||
tls_cert_data = tls_cert_path.read_text()
|
|
||||||
|
|
||||||
tls_key_data = tls_key_path.read_text()
|
|
||||||
tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
|
|
||||||
|
|
||||||
logging.info(f'skynet cert: {skynet_cert_path}')
|
|
||||||
logging.info(f'dgpu cert: {tls_cert_path}')
|
|
||||||
logging.info(f'dgpu key: {tls_key_path}')
|
|
||||||
|
|
||||||
dgpu_address = 'tls+' + dgpu_address
|
|
||||||
tls_config = TLSConfig(
|
|
||||||
TLSConfig.MODE_CLIENT,
|
|
||||||
own_key_string=tls_key_data,
|
|
||||||
own_cert_string=tls_cert_data,
|
|
||||||
ca_string=skynet_cert_data)
|
|
||||||
|
|
||||||
logging.info(f'connecting to {dgpu_address}')
|
|
||||||
|
|
||||||
dgpu_bus = ReconnectingBus(dgpu_address, tls_config)
|
|
||||||
dgpu_bus.connect()
|
|
||||||
|
|
||||||
last_msg = time.time()
|
|
||||||
async def connection_refresher(refresh_time: int = 120):
|
|
||||||
nonlocal last_msg
|
|
||||||
while True:
|
|
||||||
now = time.time()
|
|
||||||
last_msg_time_delta = now - last_msg
|
|
||||||
logging.info(f'time since last msg: {last_msg_time_delta}')
|
|
||||||
if last_msg_time_delta > refresh_time:
|
|
||||||
dgpu_bus.reconnect()
|
|
||||||
logging.info('reconnected!')
|
|
||||||
last_msg = now
|
|
||||||
|
|
||||||
await trio.sleep(refresh_time)
|
|
||||||
|
|
||||||
n.start_soon(connection_refresher)
|
|
||||||
|
|
||||||
res = await rpc_call('dgpu_online')
|
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
await trio.sleep_forever()
|
||||||
msg = await dgpu_bus.arecv()
|
|
||||||
|
|
||||||
img = None
|
|
||||||
if b'BINEXT' in msg:
|
|
||||||
header, msg, img_raw = msg.split(b'%$%$')
|
|
||||||
logging.info(f'got img attachment of size {len(img_raw)}')
|
|
||||||
logging.info(img_raw[:10])
|
|
||||||
raw_img = zlib.decompress(img_raw)
|
|
||||||
logging.info(raw_img[:10])
|
|
||||||
img = Image.open(io.BytesIO(raw_img))
|
|
||||||
w, h = img.size
|
|
||||||
logging.info(f'user sent img of size {img.size}')
|
|
||||||
|
|
||||||
if w > 512 or h > 512:
|
|
||||||
img.thumbnail((512, 512))
|
|
||||||
logging.info(f'resized it to {img.size}')
|
|
||||||
|
|
||||||
|
|
||||||
req = DGPUBusMessage()
|
|
||||||
req.ParseFromString(msg)
|
|
||||||
last_msg = time.time()
|
|
||||||
|
|
||||||
if req.method == 'heartbeat':
|
|
||||||
rep = DGPUBusMessage(
|
|
||||||
rid=req.rid,
|
|
||||||
nid=unique_id,
|
|
||||||
method=req.method
|
|
||||||
)
|
|
||||||
rep.params.update({'time': int(time.time() * 1000)})
|
|
||||||
|
|
||||||
if security:
|
|
||||||
rep.auth.cert = cert_name
|
|
||||||
rep.auth.sig = sign_protobuf_msg(rep, tls_key)
|
|
||||||
|
|
||||||
await dgpu_bus.asend(rep.SerializeToString())
|
|
||||||
logging.info('heartbeat reply')
|
|
||||||
continue
|
|
||||||
|
|
||||||
if req.nid != unique_id:
|
|
||||||
logging.info(
|
|
||||||
f'witnessed msg {req.rid}, node involved: {req.nid}')
|
|
||||||
continue
|
|
||||||
|
|
||||||
if security:
|
|
||||||
verify_protobuf_msg(req, skynet_cert)
|
|
||||||
|
|
||||||
|
|
||||||
ack_resp = DGPUBusMessage(
|
|
||||||
rid=req.rid,
|
|
||||||
nid=req.nid
|
|
||||||
)
|
|
||||||
ack_resp.params.update({'ack': {}})
|
|
||||||
|
|
||||||
if security:
|
|
||||||
ack_resp.auth.cert = cert_name
|
|
||||||
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
|
|
||||||
|
|
||||||
# send ack
|
|
||||||
await dgpu_bus.asend(ack_resp.SerializeToString())
|
|
||||||
|
|
||||||
logging.info(f'sent ack, processing {req.rid}...')
|
|
||||||
|
|
||||||
try:
|
|
||||||
img_req = DiffusionParameters(**req.params)
|
|
||||||
|
|
||||||
if not img_req.seed:
|
|
||||||
img_req.seed = random.randint(0, 2 ** 64)
|
|
||||||
|
|
||||||
img = await gpu_compute_one(img_req, image=img)
|
|
||||||
img_resp = DGPUBusMessage(
|
|
||||||
rid=req.rid,
|
|
||||||
nid=req.nid,
|
|
||||||
method='binary-reply'
|
|
||||||
)
|
|
||||||
img_resp.params.update({
|
|
||||||
'len': len(img),
|
|
||||||
'meta': img_req.to_dict()
|
|
||||||
})
|
|
||||||
|
|
||||||
except DGPUComputeError as e:
|
|
||||||
traceback.print_exception(type(e), e, e.__traceback__)
|
|
||||||
img_resp = DGPUBusMessage(
|
|
||||||
rid=req.rid,
|
|
||||||
nid=req.nid
|
|
||||||
)
|
|
||||||
img_resp.params.update({'error': str(e)})
|
|
||||||
|
|
||||||
|
|
||||||
if security:
|
|
||||||
img_resp.auth.cert = cert_name
|
|
||||||
img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key)
|
|
||||||
|
|
||||||
# send final image
|
|
||||||
logging.info('sending img back...')
|
|
||||||
raw_msg = img_resp.SerializeToString()
|
|
||||||
await dgpu_bus.asend(raw_msg)
|
|
||||||
logging.info(f'sent {len(raw_msg)} bytes.')
|
|
||||||
if img_resp.method == 'binary-reply':
|
|
||||||
await dgpu_bus.asend(zlib.compress(img))
|
|
||||||
logging.info(f'sent {len(img)} bytes.')
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logging.info('interrupt caught, stopping...')
|
logging.info('interrupt caught, stopping...')
|
||||||
n.cancel_scope.cancel()
|
|
||||||
dgpu_bus.close()
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
res = await rpc_call('dgpu_offline')
|
res = await skynet_rpc.rpc('dgpu_offline')
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
|
|
|
@ -4,7 +4,7 @@ import json
|
||||||
|
|
||||||
from typing import Union, Optional
|
from typing import Union, Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from contextlib import asynccontextmanager as acm
|
from contextlib import contextmanager as cm
|
||||||
|
|
||||||
import pynng
|
import pynng
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ from OpenSSL.crypto import (
|
||||||
|
|
||||||
from google.protobuf.struct_pb2 import Struct
|
from google.protobuf.struct_pb2 import Struct
|
||||||
|
|
||||||
|
from ..network import SessionClient
|
||||||
from ..constants import *
|
from ..constants import *
|
||||||
|
|
||||||
from ..protobuf.auth import *
|
from ..protobuf.auth import *
|
||||||
|
@ -39,75 +40,23 @@ class ConfigSizeDivisionByEight(BaseException):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
@acm
|
@cm
|
||||||
async def open_skynet_rpc(
|
def open_skynet_rpc(
|
||||||
unique_id: str,
|
unique_id: str,
|
||||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||||
security: bool = False,
|
|
||||||
cert_name: Optional[str] = None,
|
cert_name: Optional[str] = None,
|
||||||
key_name: Optional[str] = None
|
key_name: Optional[str] = None
|
||||||
):
|
):
|
||||||
tls_config = None
|
sesh = SessionClient(
|
||||||
|
rpc_address,
|
||||||
if security:
|
unique_id,
|
||||||
# load tls certs
|
cert_name=cert_name,
|
||||||
if not key_name:
|
key_name=key_name
|
||||||
key_name = cert_name
|
)
|
||||||
|
logging.debug(f'opening skynet rpc...')
|
||||||
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
|
sesh.connect()
|
||||||
|
yield sesh
|
||||||
skynet_cert_data = (certs_dir / 'brain.cert').read_text()
|
sesh.disconnect()
|
||||||
skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data)
|
|
||||||
|
|
||||||
tls_cert_path = certs_dir / f'{cert_name}.cert'
|
|
||||||
tls_cert_data = tls_cert_path.read_text()
|
|
||||||
tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data)
|
|
||||||
cert_name = tls_cert_path.stem
|
|
||||||
|
|
||||||
tls_key_data = (certs_dir / f'{key_name}.key').read_text()
|
|
||||||
tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
|
|
||||||
|
|
||||||
rpc_address = 'tls+' + rpc_address
|
|
||||||
tls_config = TLSConfig(
|
|
||||||
TLSConfig.MODE_CLIENT,
|
|
||||||
own_key_string=tls_key_data,
|
|
||||||
own_cert_string=tls_cert_data,
|
|
||||||
ca_string=skynet_cert_data)
|
|
||||||
|
|
||||||
with pynng.Req0(recv_max_size=0) as sock:
|
|
||||||
if security:
|
|
||||||
sock.tls_config = tls_config
|
|
||||||
|
|
||||||
sock.dial(rpc_address)
|
|
||||||
|
|
||||||
async def _rpc_call(
|
|
||||||
method: str,
|
|
||||||
params: dict = {},
|
|
||||||
uid: Optional[str] = None
|
|
||||||
):
|
|
||||||
req = SkynetRPCRequest()
|
|
||||||
req.uid = uid if uid else unique_id
|
|
||||||
req.method = method
|
|
||||||
req.params.update(params)
|
|
||||||
|
|
||||||
if security:
|
|
||||||
req.auth.cert = cert_name
|
|
||||||
req.auth.sig = sign_protobuf_msg(req, tls_key)
|
|
||||||
|
|
||||||
ctx = sock.new_context()
|
|
||||||
await ctx.asend(req.SerializeToString())
|
|
||||||
|
|
||||||
resp = SkynetRPCResponse()
|
|
||||||
resp.ParseFromString(await ctx.arecv())
|
|
||||||
ctx.close()
|
|
||||||
|
|
||||||
if security:
|
|
||||||
verify_protobuf_msg(resp, skynet_cert)
|
|
||||||
|
|
||||||
return resp
|
|
||||||
|
|
||||||
yield _rpc_call
|
|
||||||
|
|
||||||
|
|
||||||
def validate_user_config_request(req: str):
|
def validate_user_config_request(req: str):
|
||||||
params = req.split(' ')
|
params = req.split(' ')
|
||||||
|
|
|
@ -6,8 +6,6 @@ import logging
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import pynng
|
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from trio_asyncio import aio_as_trio
|
from trio_asyncio import aio_as_trio
|
||||||
|
|
||||||
|
@ -16,6 +14,7 @@ from telebot.types import (
|
||||||
)
|
)
|
||||||
from telebot.async_telebot import AsyncTeleBot
|
from telebot.async_telebot import AsyncTeleBot
|
||||||
|
|
||||||
|
from ..db import open_database_connection
|
||||||
from ..constants import *
|
from ..constants import *
|
||||||
|
|
||||||
from . import *
|
from . import *
|
||||||
|
@ -56,228 +55,274 @@ def prepare_metainfo_caption(tguser, meta: dict) -> str:
|
||||||
|
|
||||||
|
|
||||||
async def run_skynet_telegram(
|
async def run_skynet_telegram(
|
||||||
|
name: str,
|
||||||
tg_token: str,
|
tg_token: str,
|
||||||
key_name: str = 'telegram-frontend',
|
key_name: str = 'telegram-frontend.key',
|
||||||
cert_name: str = 'whitelist/telegram-frontend',
|
cert_name: str = 'whitelist/telegram-frontend.cert',
|
||||||
rpc_address: str = DEFAULT_RPC_ADDR
|
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||||
|
db_host: str = 'localhost:5432',
|
||||||
|
db_user: str = 'skynet',
|
||||||
|
db_pass: str = 'password'
|
||||||
):
|
):
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
bot = AsyncTeleBot(tg_token)
|
bot = AsyncTeleBot(tg_token)
|
||||||
|
logging.info(f'tg_token: {tg_token}')
|
||||||
|
|
||||||
async with open_skynet_rpc(
|
async with open_database_connection(
|
||||||
'skynet-telegram-0',
|
db_user, db_pass, db_host
|
||||||
rpc_address=rpc_address,
|
) as db_call:
|
||||||
security=True,
|
with open_skynet_rpc(
|
||||||
cert_name=cert_name,
|
f'skynet-telegram-{name}',
|
||||||
key_name=key_name
|
rpc_address=rpc_address,
|
||||||
) as rpc_call:
|
cert_name=cert_name,
|
||||||
|
key_name=key_name
|
||||||
|
) as session:
|
||||||
|
|
||||||
async def _rpc_call(
|
@bot.message_handler(commands=['help'])
|
||||||
uid: int,
|
async def send_help(message):
|
||||||
method: str,
|
splt_msg = message.text.split(' ')
|
||||||
params: dict = {}
|
|
||||||
):
|
|
||||||
return await rpc_call(
|
|
||||||
method, params, uid=f'{PREFIX}+{uid}')
|
|
||||||
|
|
||||||
@bot.message_handler(commands=['help'])
|
if len(splt_msg) == 1:
|
||||||
async def send_help(message):
|
await bot.reply_to(message, HELP_TEXT)
|
||||||
splt_msg = message.text.split(' ')
|
|
||||||
|
|
||||||
if len(splt_msg) == 1:
|
|
||||||
await bot.reply_to(message, HELP_TEXT)
|
|
||||||
|
|
||||||
else:
|
|
||||||
param = splt_msg[1]
|
|
||||||
if param in HELP_TOPICS:
|
|
||||||
await bot.reply_to(message, HELP_TOPICS[param])
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
|
param = splt_msg[1]
|
||||||
|
if param in HELP_TOPICS:
|
||||||
|
await bot.reply_to(message, HELP_TOPICS[param])
|
||||||
|
|
||||||
@bot.message_handler(commands=['cool'])
|
else:
|
||||||
async def send_cool_words(message):
|
await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
|
||||||
await bot.reply_to(message, '\n'.join(COOL_WORDS))
|
|
||||||
|
|
||||||
@bot.message_handler(commands=['txt2img'])
|
@bot.message_handler(commands=['cool'])
|
||||||
async def send_txt2img(message):
|
async def send_cool_words(message):
|
||||||
chat = message.chat
|
await bot.reply_to(message, '\n'.join(COOL_WORDS))
|
||||||
|
|
||||||
prompt = ' '.join(message.text.split(' ')[1:])
|
@bot.message_handler(commands=['txt2img'])
|
||||||
|
async def send_txt2img(message):
|
||||||
|
chat = message.chat
|
||||||
|
reply_id = None
|
||||||
|
if chat.type == 'group' and chat.id == GROUP_ID:
|
||||||
|
reply_id = message.message_id
|
||||||
|
|
||||||
if len(prompt) == 0:
|
user_id = f'tg+{message.from_user.id}'
|
||||||
await bot.reply_to(message, 'Empty text prompt ignored.')
|
|
||||||
return
|
|
||||||
|
|
||||||
logging.info(f'mid: {message.id}')
|
prompt = ' '.join(message.text.split(' ')[1:])
|
||||||
resp = await _rpc_call(
|
|
||||||
message.from_user.id,
|
|
||||||
'txt2img',
|
|
||||||
{'prompt': prompt}
|
|
||||||
)
|
|
||||||
logging.info(f'resp to {message.id} arrived')
|
|
||||||
|
|
||||||
resp_txt = ''
|
if len(prompt) == 0:
|
||||||
result = MessageToDict(resp.result)
|
await bot.reply_to(message, 'Empty text prompt ignored.')
|
||||||
if 'error' in resp.result:
|
return
|
||||||
resp_txt = resp.result['message']
|
|
||||||
|
|
||||||
else:
|
logging.info(f'mid: {message.id}')
|
||||||
logging.info(result['id'])
|
user = await db_call('get_or_create_user', user_id)
|
||||||
img_raw = zlib.decompress(bytes.fromhex(result['img']))
|
user_config = {**(await db_call('get_user_config', user))}
|
||||||
logging.info(f'got image of size: {len(img_raw)}')
|
del user_config['id']
|
||||||
img = Image.open(io.BytesIO(img_raw))
|
|
||||||
|
|
||||||
await bot.send_photo(
|
resp = await session.rpc(
|
||||||
GROUP_ID,
|
'dgpu_call', {
|
||||||
caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']),
|
'method': 'diffuse',
|
||||||
photo=img,
|
'params': {
|
||||||
reply_markup=build_redo_menu()
|
'prompt': prompt,
|
||||||
|
**user_config
|
||||||
|
}
|
||||||
|
},
|
||||||
|
timeout=60
|
||||||
)
|
)
|
||||||
return
|
logging.info(f'resp to {message.id} arrived')
|
||||||
|
|
||||||
await bot.reply_to(message, resp_txt)
|
resp_txt = ''
|
||||||
|
result = MessageToDict(resp.result)
|
||||||
|
if 'error' in resp.result:
|
||||||
|
resp_txt = resp.result['message']
|
||||||
|
await bot.reply_to(message, resp_txt)
|
||||||
|
|
||||||
@bot.message_handler(func=lambda message: True, content_types=['photo'])
|
else:
|
||||||
async def send_img2img(message):
|
logging.info(result['id'])
|
||||||
chat = message.chat
|
img_raw = resp.bin
|
||||||
|
logging.info(f'got image of size: {len(img_raw)}')
|
||||||
|
img = Image.open(io.BytesIO(img_raw))
|
||||||
|
|
||||||
if not message.caption.startswith('/img2img'):
|
await bot.send_photo(
|
||||||
return
|
GROUP_ID,
|
||||||
|
caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']),
|
||||||
|
photo=img,
|
||||||
|
reply_to_message_id=reply_id,
|
||||||
|
reply_markup=build_redo_menu()
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
prompt = ' '.join(message.caption.split(' ')[1:])
|
|
||||||
|
|
||||||
if len(prompt) == 0:
|
@bot.message_handler(func=lambda message: True, content_types=['photo'])
|
||||||
await bot.reply_to(message, 'Empty text prompt ignored.')
|
async def send_img2img(message):
|
||||||
return
|
chat = message.chat
|
||||||
|
reply_id = None
|
||||||
|
if chat.type == 'group' and chat.id == GROUP_ID:
|
||||||
|
reply_id = message.message_id
|
||||||
|
|
||||||
file_id = message.photo[-1].file_id
|
user_id = f'tg+{message.from_user.id}'
|
||||||
file_path = (await bot.get_file(file_id)).file_path
|
|
||||||
file_raw = await bot.download_file(file_path)
|
|
||||||
img = zlib.compress(file_raw)
|
|
||||||
|
|
||||||
logging.info(f'mid: {message.id}')
|
if not message.caption.startswith('/img2img'):
|
||||||
resp = await _rpc_call(
|
await bot.reply_to(
|
||||||
message.from_user.id,
|
message,
|
||||||
'img2img',
|
'For image to image you need to add /img2img to the beggining of your caption'
|
||||||
{'prompt': prompt, 'img': img.hex()}
|
)
|
||||||
)
|
return
|
||||||
logging.info(f'resp to {message.id} arrived')
|
|
||||||
|
|
||||||
resp_txt = ''
|
prompt = ' '.join(message.caption.split(' ')[1:])
|
||||||
result = MessageToDict(resp.result)
|
|
||||||
if 'error' in resp.result:
|
|
||||||
resp_txt = resp.result['message']
|
|
||||||
|
|
||||||
else:
|
if len(prompt) == 0:
|
||||||
logging.info(result['id'])
|
await bot.reply_to(message, 'Empty text prompt ignored.')
|
||||||
img_raw = zlib.decompress(bytes.fromhex(result['img']))
|
return
|
||||||
logging.info(f'got image of size: {len(img_raw)}')
|
|
||||||
img = Image.open(io.BytesIO(img_raw))
|
|
||||||
|
|
||||||
await bot.send_media_group(
|
file_id = message.photo[-1].file_id
|
||||||
GROUP_ID,
|
file_path = (await bot.get_file(file_id)).file_path
|
||||||
media=[
|
file_raw = await bot.download_file(file_path)
|
||||||
InputMediaPhoto(file_id),
|
|
||||||
InputMediaPhoto(
|
logging.info(f'mid: {message.id}')
|
||||||
img,
|
|
||||||
caption=prepare_metainfo_caption(message.from_user, result['meta']['meta'])
|
user = await db_call('get_or_create_user', user_id)
|
||||||
)
|
user_config = {**(await db_call('get_user_config', user))}
|
||||||
]
|
del user_config['id']
|
||||||
|
|
||||||
|
resp = await session.rpc(
|
||||||
|
'dgpu_call', {
|
||||||
|
'method': 'diffuse',
|
||||||
|
'params': {
|
||||||
|
'prompt': prompt,
|
||||||
|
**user_config
|
||||||
|
}
|
||||||
|
},
|
||||||
|
binext=file_raw,
|
||||||
|
timeout=60
|
||||||
)
|
)
|
||||||
return
|
logging.info(f'resp to {message.id} arrived')
|
||||||
|
|
||||||
await bot.reply_to(message, resp_txt)
|
resp_txt = ''
|
||||||
|
result = MessageToDict(resp.result)
|
||||||
|
if 'error' in resp.result:
|
||||||
|
resp_txt = resp.result['message']
|
||||||
|
await bot.reply_to(message, resp_txt)
|
||||||
|
|
||||||
@bot.message_handler(commands=['img2img'])
|
else:
|
||||||
async def redo_txt2img(message):
|
logging.info(result['id'])
|
||||||
await bot.reply_to(
|
img_raw = resp.bin
|
||||||
message,
|
logging.info(f'got image of size: {len(img_raw)}')
|
||||||
'seems you tried to do an img2img command without sending image'
|
img = Image.open(io.BytesIO(img_raw))
|
||||||
)
|
|
||||||
|
|
||||||
async def _redo(message):
|
await bot.send_media_group(
|
||||||
resp = await _rpc_call(message.from_user.id, 'redo')
|
GROUP_ID,
|
||||||
|
media=[
|
||||||
|
InputMediaPhoto(file_id),
|
||||||
|
InputMediaPhoto(
|
||||||
|
img,
|
||||||
|
caption=prepare_metainfo_caption(message.from_user, result['meta']['meta'])
|
||||||
|
)
|
||||||
|
],
|
||||||
|
reply_to_message_id=reply_id
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
resp_txt = ''
|
|
||||||
result = MessageToDict(resp.result)
|
|
||||||
if 'error' in resp.result:
|
|
||||||
resp_txt = resp.result['message']
|
|
||||||
|
|
||||||
else:
|
@bot.message_handler(commands=['img2img'])
|
||||||
logging.info(result['id'])
|
async def img2img_missing_image(message):
|
||||||
img_raw = zlib.decompress(bytes.fromhex(result['img']))
|
await bot.reply_to(
|
||||||
logging.info(f'got image of size: {len(img_raw)}')
|
message,
|
||||||
img = Image.open(io.BytesIO(img_raw))
|
'seems you tried to do an img2img command without sending image'
|
||||||
|
|
||||||
await bot.send_photo(
|
|
||||||
GROUP_ID,
|
|
||||||
caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']),
|
|
||||||
photo=img,
|
|
||||||
reply_markup=build_redo_menu()
|
|
||||||
)
|
)
|
||||||
return
|
|
||||||
|
|
||||||
await bot.reply_to(message, resp_txt)
|
@bot.message_handler(commands=['redo'])
|
||||||
|
async def redo(message):
|
||||||
|
chat = message.chat
|
||||||
|
reply_id = None
|
||||||
|
if chat.type == 'group' and chat.id == GROUP_ID:
|
||||||
|
reply_id = message.message_id
|
||||||
|
|
||||||
@bot.message_handler(commands=['redo'])
|
user_config = {**(await db_call('get_user_config', user))}
|
||||||
async def redo_txt2img(message):
|
del user_config['id']
|
||||||
await _redo(message)
|
prompt = await db_call('get_last_prompt_of', user)
|
||||||
|
|
||||||
@bot.message_handler(commands=['config'])
|
resp = await session.rpc(
|
||||||
async def set_config(message):
|
'dgpu_call', {
|
||||||
rpc_params = {}
|
'method': 'diffuse',
|
||||||
try:
|
'params': {
|
||||||
attr, val, reply_txt = validate_user_config_request(
|
'prompt': prompt,
|
||||||
message.text)
|
**user_config
|
||||||
|
}
|
||||||
|
},
|
||||||
|
timeout=60
|
||||||
|
)
|
||||||
|
logging.info(f'resp to {message.id} arrived')
|
||||||
|
|
||||||
resp = await _rpc_call(
|
resp_txt = ''
|
||||||
message.from_user.id,
|
result = MessageToDict(resp.result)
|
||||||
'config', {'attr': attr, 'val': val})
|
if 'error' in resp.result:
|
||||||
|
resp_txt = resp.result['message']
|
||||||
|
await bot.reply_to(message, resp_txt)
|
||||||
|
|
||||||
except BaseException as e:
|
else:
|
||||||
reply_txt = str(e)
|
logging.info(result['id'])
|
||||||
|
img_raw = resp.bin
|
||||||
|
logging.info(f'got image of size: {len(img_raw)}')
|
||||||
|
img = Image.open(io.BytesIO(img_raw))
|
||||||
|
|
||||||
finally:
|
await bot.send_photo(
|
||||||
await bot.reply_to(message, reply_txt)
|
GROUP_ID,
|
||||||
|
caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']),
|
||||||
|
photo=img,
|
||||||
|
reply_to_message_id=reply_id
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
@bot.message_handler(commands=['stats'])
|
@bot.message_handler(commands=['config'])
|
||||||
async def user_stats(message):
|
async def set_config(message):
|
||||||
resp = await _rpc_call(
|
rpc_params = {}
|
||||||
message.from_user.id,
|
try:
|
||||||
'stats',
|
attr, val, reply_txt = validate_user_config_request(
|
||||||
{}
|
message.text)
|
||||||
)
|
|
||||||
stats = resp.result
|
|
||||||
|
|
||||||
stats_str = f'generated: {stats["generated"]}\n'
|
logging.info(f'user config update: {attr} to {val}')
|
||||||
stats_str += f'joined: {stats["joined"]}\n'
|
await db_call('update_user_config',
|
||||||
stats_str += f'role: {stats["role"]}\n'
|
user, req.params['attr'], req.params['val'])
|
||||||
|
logging.info('done')
|
||||||
|
|
||||||
await bot.reply_to(
|
except BaseException as e:
|
||||||
message, stats_str)
|
reply_txt = str(e)
|
||||||
|
|
||||||
@bot.message_handler(commands=['donate'])
|
finally:
|
||||||
async def donation_info(message):
|
await bot.reply_to(message, reply_txt)
|
||||||
await bot.reply_to(
|
|
||||||
message, DONATION_INFO)
|
|
||||||
|
|
||||||
@bot.message_handler(commands=['say'])
|
@bot.message_handler(commands=['stats'])
|
||||||
async def say(message):
|
async def user_stats(message):
|
||||||
chat = message.chat
|
|
||||||
user = message.from_user
|
|
||||||
|
|
||||||
if (chat.type == 'group') or (user.id != 383385940):
|
generated, joined, role = await db_call('get_user_stats', user)
|
||||||
return
|
|
||||||
|
|
||||||
await bot.send_message(GROUP_ID, message.text[4:])
|
stats_str = f'generated: {generated}\n'
|
||||||
|
stats_str += f'joined: {joined}\n'
|
||||||
|
stats_str += f'role: {role}\n'
|
||||||
|
|
||||||
|
await bot.reply_to(
|
||||||
|
message, stats_str)
|
||||||
|
|
||||||
|
@bot.message_handler(commands=['donate'])
|
||||||
|
async def donation_info(message):
|
||||||
|
await bot.reply_to(
|
||||||
|
message, DONATION_INFO)
|
||||||
|
|
||||||
|
@bot.message_handler(commands=['say'])
|
||||||
|
async def say(message):
|
||||||
|
chat = message.chat
|
||||||
|
user = message.from_user
|
||||||
|
|
||||||
|
if (chat.type == 'group') or (user.id != 383385940):
|
||||||
|
return
|
||||||
|
|
||||||
|
await bot.send_message(GROUP_ID, message.text[4:])
|
||||||
|
|
||||||
|
|
||||||
@bot.message_handler(func=lambda message: True)
|
@bot.message_handler(func=lambda message: True)
|
||||||
async def echo_message(message):
|
async def echo_message(message):
|
||||||
if message.text[0] == '/':
|
if message.text[0] == '/':
|
||||||
await bot.reply_to(message, UNKNOWN_CMD_TEXT)
|
await bot.reply_to(message, UNKNOWN_CMD_TEXT)
|
||||||
|
|
||||||
@bot.callback_query_handler(func=lambda call: True)
|
@bot.callback_query_handler(func=lambda call: True)
|
||||||
async def callback_query(call):
|
async def callback_query(call):
|
||||||
|
@ -289,4 +334,4 @@ async def run_skynet_telegram(
|
||||||
await _redo(call)
|
await _redo(call)
|
||||||
|
|
||||||
|
|
||||||
await aio_as_trio(bot.infinity_polling())
|
await aio_as_trio(bot.infinity_polling)()
|
||||||
|
|
|
@ -0,0 +1,341 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import zlib
|
||||||
|
import socket
|
||||||
|
|
||||||
|
from typing import Callable, Awaitable, Optional
|
||||||
|
from pathlib import Path
|
||||||
|
from contextlib import asynccontextmanager as acm
|
||||||
|
from cryptography import x509
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import pynng
|
||||||
|
|
||||||
|
from pynng import TLSConfig, Context
|
||||||
|
|
||||||
|
from .protobuf import *
|
||||||
|
from .constants import *
|
||||||
|
|
||||||
|
|
||||||
|
def get_random_port():
|
||||||
|
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
s.bind(('', 0))
|
||||||
|
return s.getsockname()[1]
|
||||||
|
|
||||||
|
|
||||||
|
def load_certs(
|
||||||
|
certs_dir: str,
|
||||||
|
cert_name: str,
|
||||||
|
key_name: str
|
||||||
|
):
|
||||||
|
certs_dir = Path(certs_dir).resolve()
|
||||||
|
tls_key_data = (certs_dir / key_name).read_bytes()
|
||||||
|
tls_key = serialization.load_pem_private_key(
|
||||||
|
tls_key_data,
|
||||||
|
password=None
|
||||||
|
)
|
||||||
|
|
||||||
|
tls_cert_data = (certs_dir / cert_name).read_bytes()
|
||||||
|
tls_cert = x509.load_pem_x509_certificate(
|
||||||
|
tls_cert_data
|
||||||
|
)
|
||||||
|
|
||||||
|
tls_whitelist = {}
|
||||||
|
for cert_path in (*(certs_dir / 'whitelist').glob('*.cert'), certs_dir / 'brain.cert'):
|
||||||
|
tls_whitelist[cert_path.stem] = x509.load_pem_x509_certificate(
|
||||||
|
cert_path.read_bytes()
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
SessionTLSConfig(
|
||||||
|
TLSConfig.MODE_SERVER,
|
||||||
|
own_key_string=tls_key_data,
|
||||||
|
own_cert_string=tls_cert_data
|
||||||
|
),
|
||||||
|
|
||||||
|
tls_whitelist
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_certs_client(
|
||||||
|
certs_dir: str,
|
||||||
|
cert_name: str,
|
||||||
|
key_name: str,
|
||||||
|
ca_name: Optional[str] = None
|
||||||
|
):
|
||||||
|
certs_dir = Path(certs_dir).resolve()
|
||||||
|
if not ca_name:
|
||||||
|
ca_name = 'brain.cert'
|
||||||
|
|
||||||
|
ca_cert_data = (certs_dir / ca_name).read_bytes()
|
||||||
|
|
||||||
|
tls_key_data = (certs_dir / key_name).read_bytes()
|
||||||
|
|
||||||
|
|
||||||
|
tls_cert_data = (certs_dir / cert_name).read_bytes()
|
||||||
|
|
||||||
|
|
||||||
|
tls_whitelist = {}
|
||||||
|
for cert_path in (*(certs_dir / 'whitelist').glob('*.cert'), certs_dir / 'brain.cert'):
|
||||||
|
tls_whitelist[cert_path.stem] = x509.load_pem_x509_certificate(
|
||||||
|
cert_path.read_bytes()
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
SessionTLSConfig(
|
||||||
|
TLSConfig.MODE_CLIENT,
|
||||||
|
own_key_string=tls_key_data,
|
||||||
|
own_cert_string=tls_cert_data,
|
||||||
|
ca_string=ca_cert_data
|
||||||
|
),
|
||||||
|
|
||||||
|
tls_whitelist
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionError(BaseException):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class SessionTLSConfig(TLSConfig):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mode,
|
||||||
|
server_name=None,
|
||||||
|
ca_string=None,
|
||||||
|
own_key_string=None,
|
||||||
|
own_cert_string=None,
|
||||||
|
auth_mode=None,
|
||||||
|
ca_files=None,
|
||||||
|
cert_key_file=None,
|
||||||
|
passwd=None
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
mode,
|
||||||
|
server_name=server_name,
|
||||||
|
ca_string=ca_string,
|
||||||
|
own_key_string=own_key_string,
|
||||||
|
own_cert_string=own_cert_string,
|
||||||
|
auth_mode=auth_mode,
|
||||||
|
ca_files=ca_files,
|
||||||
|
cert_key_file=cert_key_file,
|
||||||
|
passwd=passwd
|
||||||
|
)
|
||||||
|
|
||||||
|
if ca_string:
|
||||||
|
self.ca_cert = x509.load_pem_x509_certificate(ca_string)
|
||||||
|
|
||||||
|
self.cert = x509.load_pem_x509_certificate(own_cert_string)
|
||||||
|
self.key = serialization.load_pem_private_key(
|
||||||
|
own_key_string,
|
||||||
|
password=passwd
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionServer:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
addr: str,
|
||||||
|
msg_handler: Callable[
|
||||||
|
[SkynetRPCRequest, Context], Awaitable[SkynetRPCResponse]
|
||||||
|
],
|
||||||
|
cert_name: Optional[str] = None,
|
||||||
|
key_name: Optional[str] = None,
|
||||||
|
cert_dir: str = DEFAULT_CERTS_DIR,
|
||||||
|
recv_max_size = 0
|
||||||
|
):
|
||||||
|
self.addr = addr
|
||||||
|
self.msg_handler = msg_handler
|
||||||
|
|
||||||
|
self.cert_name = cert_name
|
||||||
|
self.tls_config = None
|
||||||
|
self.tls_whitelist = None
|
||||||
|
if cert_name and key_name:
|
||||||
|
self.cert_name = cert_name
|
||||||
|
self.tls_config, self.tls_whitelist = load_certs(
|
||||||
|
cert_dir, cert_name, key_name)
|
||||||
|
|
||||||
|
self.addr = 'tls+' + self.addr
|
||||||
|
|
||||||
|
self.recv_max_size = recv_max_size
|
||||||
|
|
||||||
|
async def _handle_msg(self, req: SkynetRPCRequest, ctx: Context):
|
||||||
|
resp = await self.msg_handler(req, ctx)
|
||||||
|
|
||||||
|
if self.tls_config:
|
||||||
|
resp.auth.cert = 'skynet'
|
||||||
|
resp.auth.sig = sign_protobuf_msg(
|
||||||
|
resp, self.tls_config.key)
|
||||||
|
|
||||||
|
raw_msg = zlib.compress(resp.SerializeToString())
|
||||||
|
|
||||||
|
await ctx.asend(raw_msg)
|
||||||
|
|
||||||
|
ctx.close()
|
||||||
|
|
||||||
|
async def _listener (self, sock):
|
||||||
|
async with trio.open_nursery() as n:
|
||||||
|
while True:
|
||||||
|
ctx = sock.new_context()
|
||||||
|
|
||||||
|
raw_msg = await ctx.arecv()
|
||||||
|
raw_size = len(raw_msg)
|
||||||
|
logging.debug(f'rpc server new msg {raw_size} bytes')
|
||||||
|
|
||||||
|
try:
|
||||||
|
msg = zlib.decompress(raw_msg)
|
||||||
|
msg_size = len(msg)
|
||||||
|
|
||||||
|
except zlib.error:
|
||||||
|
logging.warning(f'Zlib decompress error, dropping msg of size {len(raw_msg)}')
|
||||||
|
continue
|
||||||
|
|
||||||
|
logging.debug(f'msg after decompress {msg_size} bytes, +{msg_size - raw_size} bytes')
|
||||||
|
|
||||||
|
req = SkynetRPCRequest()
|
||||||
|
try:
|
||||||
|
req.ParseFromString(msg)
|
||||||
|
|
||||||
|
except google.protobuf.message.DecodeError:
|
||||||
|
logging.warning(f'Dropping malfomed msg of size {len(msg)}')
|
||||||
|
continue
|
||||||
|
|
||||||
|
logging.debug(f'msg method: {req.method}')
|
||||||
|
|
||||||
|
if self.tls_config:
|
||||||
|
if req.auth.cert not in self.tls_whitelist:
|
||||||
|
logging.warning(
|
||||||
|
f'{req.auth.cert} not in tls whitelist')
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
verify_protobuf_msg(req, self.tls_whitelist[req.auth.cert])
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
logging.warning(
|
||||||
|
f'{req.cert} sent an unauthenticated msg')
|
||||||
|
continue
|
||||||
|
|
||||||
|
n.start_soon(self._handle_msg, req, ctx)
|
||||||
|
|
||||||
|
@acm
|
||||||
|
async def open(self):
|
||||||
|
with pynng.Rep0(
|
||||||
|
recv_max_size=self.recv_max_size
|
||||||
|
) as sock:
|
||||||
|
|
||||||
|
if self.tls_config:
|
||||||
|
sock.tls_config = self.tls_config
|
||||||
|
|
||||||
|
sock.listen(self.addr)
|
||||||
|
|
||||||
|
logging.debug(f'server socket listening at {self.addr}')
|
||||||
|
|
||||||
|
async with trio.open_nursery() as n:
|
||||||
|
n.start_soon(self._listener, sock)
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self
|
||||||
|
|
||||||
|
finally:
|
||||||
|
n.cancel_scope.cancel()
|
||||||
|
|
||||||
|
logging.debug('server socket is off.')
|
||||||
|
|
||||||
|
|
||||||
|
class SessionClient:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
connect_addr: str,
|
||||||
|
uid: str,
|
||||||
|
cert_name: Optional[str] = None,
|
||||||
|
key_name: Optional[str] = None,
|
||||||
|
ca_name: Optional[str] = None,
|
||||||
|
cert_dir: str = DEFAULT_CERTS_DIR,
|
||||||
|
recv_max_size = 0
|
||||||
|
):
|
||||||
|
self.uid = uid
|
||||||
|
self.connect_addr = connect_addr
|
||||||
|
|
||||||
|
self.cert_name = None
|
||||||
|
self.tls_config = None
|
||||||
|
self.tls_whitelist = None
|
||||||
|
self.tls_cert = None
|
||||||
|
self.tls_key = None
|
||||||
|
if cert_name and key_name:
|
||||||
|
self.cert_name = Path(cert_name).stem
|
||||||
|
self.tls_config, self.tls_whitelist = load_certs_client(
|
||||||
|
cert_dir, cert_name, key_name, ca_name=ca_name)
|
||||||
|
|
||||||
|
if not self.connect_addr.startswith('tls'):
|
||||||
|
self.connect_addr = 'tls+' + self.connect_addr
|
||||||
|
|
||||||
|
self.recv_max_size = recv_max_size
|
||||||
|
|
||||||
|
self._connected = False
|
||||||
|
self._sock = None
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
self._sock = pynng.Req0(
|
||||||
|
recv_max_size=0,
|
||||||
|
name=self.uid
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.tls_config:
|
||||||
|
self._sock.tls_config = self.tls_config
|
||||||
|
|
||||||
|
logging.debug(f'client is dialing {self.connect_addr}...')
|
||||||
|
self._sock.dial(self.connect_addr, block=True)
|
||||||
|
self._connected = True
|
||||||
|
logging.debug(f'client is connected to {self.connect_addr}')
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
self._sock.close()
|
||||||
|
self._connected = False
|
||||||
|
logging.debug(f'client disconnected.')
|
||||||
|
|
||||||
|
async def rpc(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
params: dict = {},
|
||||||
|
binext: Optional[bytes] = None,
|
||||||
|
timeout: float = 2.
|
||||||
|
):
|
||||||
|
if not self._connected:
|
||||||
|
raise SessionError('tried to use rpc without connecting')
|
||||||
|
|
||||||
|
req = SkynetRPCRequest()
|
||||||
|
req.uid = self.uid
|
||||||
|
req.method = method
|
||||||
|
req.params.update(params)
|
||||||
|
if binext:
|
||||||
|
logging.debug('added binary extension')
|
||||||
|
req.bin = binext
|
||||||
|
|
||||||
|
if self.tls_config:
|
||||||
|
req.auth.cert = self.cert_name
|
||||||
|
req.auth.sig = sign_protobuf_msg(req, self.tls_config.key)
|
||||||
|
|
||||||
|
with trio.fail_after(timeout):
|
||||||
|
ctx = self._sock.new_context()
|
||||||
|
raw_req = zlib.compress(req.SerializeToString())
|
||||||
|
logging.debug(f'rpc client sending new msg {method} of size {len(raw_req)}')
|
||||||
|
await ctx.asend(raw_req)
|
||||||
|
logging.debug('sent, awaiting response...')
|
||||||
|
raw_resp = await ctx.arecv()
|
||||||
|
logging.debug(f'rpc client got response of size {len(raw_resp)}')
|
||||||
|
raw_resp = zlib.decompress(raw_resp)
|
||||||
|
|
||||||
|
resp = SkynetRPCResponse()
|
||||||
|
resp.ParseFromString(raw_resp)
|
||||||
|
ctx.close()
|
||||||
|
|
||||||
|
if self.tls_config:
|
||||||
|
verify_protobuf_msg(resp, self.tls_config.ca_cert)
|
||||||
|
|
||||||
|
return resp
|
|
@ -1,29 +1,4 @@
|
||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
from dataclasses import dataclass, asdict
|
|
||||||
|
|
||||||
from google.protobuf.json_format import MessageToDict
|
|
||||||
|
|
||||||
from .auth import *
|
from .auth import *
|
||||||
from .skynet_pb2 import *
|
from .skynet_pb2 import *
|
||||||
|
|
||||||
|
|
||||||
class Struct:
|
|
||||||
|
|
||||||
def to_dict(self):
|
|
||||||
return asdict(self)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class DiffusionParameters(Struct):
|
|
||||||
algo: str
|
|
||||||
prompt: str
|
|
||||||
step: int
|
|
||||||
width: int
|
|
||||||
height: int
|
|
||||||
guidance: float
|
|
||||||
strength: float
|
|
||||||
seed: Optional[int]
|
|
||||||
image: bool # if true indicates a bytestream is next msg
|
|
||||||
upscaler: Optional[str]
|
|
||||||
|
|
|
@ -7,7 +7,8 @@ from hashlib import sha256
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
from google.protobuf.json_format import MessageToDict
|
from google.protobuf.json_format import MessageToDict
|
||||||
from OpenSSL.crypto import PKey, X509, verify, sign
|
from cryptography.hazmat.primitives import serialization, hashes
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import padding
|
||||||
|
|
||||||
from .skynet_pb2 import *
|
from .skynet_pb2 import *
|
||||||
|
|
||||||
|
@ -46,20 +47,23 @@ def serialize_msg_deterministic(msg):
|
||||||
if field_descriptor.message_type.name == 'Struct':
|
if field_descriptor.message_type.name == 'Struct':
|
||||||
hash_dict(MessageToDict(getattr(msg, field_name)))
|
hash_dict(MessageToDict(getattr(msg, field_name)))
|
||||||
|
|
||||||
deterministic_msg = shasum.hexdigest()
|
deterministic_msg = shasum.digest()
|
||||||
|
|
||||||
return deterministic_msg
|
return deterministic_msg
|
||||||
|
|
||||||
|
|
||||||
def sign_protobuf_msg(msg, key: PKey):
|
def sign_protobuf_msg(msg, key):
|
||||||
return sign(
|
return key.sign(
|
||||||
key, serialize_msg_deterministic(msg), 'sha256').hex()
|
serialize_msg_deterministic(msg),
|
||||||
|
padding.PKCS1v15(),
|
||||||
|
hashes.SHA256()
|
||||||
|
).hex()
|
||||||
|
|
||||||
|
|
||||||
def verify_protobuf_msg(msg, cert: X509):
|
def verify_protobuf_msg(msg, cert):
|
||||||
return verify(
|
return cert.public_key().verify(
|
||||||
cert,
|
|
||||||
bytes.fromhex(msg.auth.sig),
|
bytes.fromhex(msg.auth.sig),
|
||||||
serialize_msg_deterministic(msg),
|
serialize_msg_deterministic(msg),
|
||||||
'sha256'
|
padding.PKCS1v15(),
|
||||||
|
hashes.SHA256()
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,18 +13,12 @@ message SkynetRPCRequest {
|
||||||
string uid = 1;
|
string uid = 1;
|
||||||
string method = 2;
|
string method = 2;
|
||||||
google.protobuf.Struct params = 3;
|
google.protobuf.Struct params = 3;
|
||||||
optional Auth auth = 4;
|
optional bytes bin = 4;
|
||||||
|
optional Auth auth = 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
message SkynetRPCResponse {
|
message SkynetRPCResponse {
|
||||||
google.protobuf.Struct result = 1;
|
google.protobuf.Struct result = 1;
|
||||||
optional Auth auth = 2;
|
optional bytes bin = 2;
|
||||||
}
|
optional Auth auth = 3;
|
||||||
|
|
||||||
message DGPUBusMessage {
|
|
||||||
string rid = 1;
|
|
||||||
string nid = 2;
|
|
||||||
string method = 3;
|
|
||||||
google.protobuf.Struct params = 4;
|
|
||||||
optional Auth auth = 5;
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
|
||||||
from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2
|
from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2
|
||||||
|
|
||||||
|
|
||||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cskynet.proto\x12\x06skynet\x1a\x1cgoogle/protobuf/struct.proto\"!\n\x04\x41uth\x12\x0c\n\x04\x63\x65rt\x18\x01 \x01(\t\x12\x0b\n\x03sig\x18\x02 \x01(\t\"\x82\x01\n\x10SkynetRPCRequest\x12\x0b\n\x03uid\x18\x01 \x01(\t\x12\x0e\n\x06method\x18\x02 \x01(\t\x12\'\n\x06params\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x04 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"f\n\x11SkynetRPCResponse\x12\'\n\x06result\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x02 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"\x8d\x01\n\x0e\x44GPUBusMessage\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x0b\n\x03nid\x18\x02 \x01(\t\x12\x0e\n\x06method\x18\x03 \x01(\t\x12\'\n\x06params\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x05 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_authb\x06proto3')
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cskynet.proto\x12\x06skynet\x1a\x1cgoogle/protobuf/struct.proto\"!\n\x04\x41uth\x12\x0c\n\x04\x63\x65rt\x18\x01 \x01(\t\x12\x0b\n\x03sig\x18\x02 \x01(\t\"\x9c\x01\n\x10SkynetRPCRequest\x12\x0b\n\x03uid\x18\x01 \x01(\t\x12\x0e\n\x06method\x18\x02 \x01(\t\x12\'\n\x06params\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x10\n\x03\x62in\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x12\x1f\n\x04\x61uth\x18\x05 \x01(\x0b\x32\x0c.skynet.AuthH\x01\x88\x01\x01\x42\x06\n\x04_binB\x07\n\x05_auth\"\x80\x01\n\x11SkynetRPCResponse\x12\'\n\x06result\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x10\n\x03\x62in\x18\x02 \x01(\x0cH\x00\x88\x01\x01\x12\x1f\n\x04\x61uth\x18\x03 \x01(\x0b\x32\x0c.skynet.AuthH\x01\x88\x01\x01\x42\x06\n\x04_binB\x07\n\x05_authb\x06proto3')
|
||||||
|
|
||||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'skynet_pb2', globals())
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'skynet_pb2', globals())
|
||||||
|
@ -24,9 +24,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
||||||
_AUTH._serialized_start=54
|
_AUTH._serialized_start=54
|
||||||
_AUTH._serialized_end=87
|
_AUTH._serialized_end=87
|
||||||
_SKYNETRPCREQUEST._serialized_start=90
|
_SKYNETRPCREQUEST._serialized_start=90
|
||||||
_SKYNETRPCREQUEST._serialized_end=220
|
_SKYNETRPCREQUEST._serialized_end=246
|
||||||
_SKYNETRPCRESPONSE._serialized_start=222
|
_SKYNETRPCRESPONSE._serialized_start=249
|
||||||
_SKYNETRPCRESPONSE._serialized_end=324
|
_SKYNETRPCRESPONSE._serialized_end=377
|
||||||
_DGPUBUSMESSAGE._serialized_start=327
|
|
||||||
_DGPUBUSMESSAGE._serialized_end=468
|
|
||||||
# @@protoc_insertion_point(module_scope)
|
# @@protoc_insertion_point(module_scope)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import time
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
@ -21,6 +22,10 @@ from huggingface_hub import login
|
||||||
from .constants import ALGOS
|
from .constants import ALGOS
|
||||||
|
|
||||||
|
|
||||||
|
def time_ms():
|
||||||
|
return int(time.time() * 1000)
|
||||||
|
|
||||||
|
|
||||||
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
|
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
|
||||||
# return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
# return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||||||
return Image.fromarray(img)
|
return Image.fromarray(img)
|
||||||
|
@ -164,3 +169,13 @@ def upscale(
|
||||||
|
|
||||||
|
|
||||||
image.save(output)
|
image.save(output)
|
||||||
|
|
||||||
|
|
||||||
|
def download_all_models(hf_token: str):
|
||||||
|
assert torch.cuda.is_available()
|
||||||
|
|
||||||
|
login(token=hf_token)
|
||||||
|
for model in ALGOS:
|
||||||
|
print(f'DOWNLOADING {model.upper()}')
|
||||||
|
pipeline_for(model)
|
||||||
|
|
||||||
|
|
|
@ -3,89 +3,30 @@
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import random
|
|
||||||
import string
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import trio
|
|
||||||
import pytest
|
import pytest
|
||||||
import psycopg2
|
|
||||||
import trio_asyncio
|
|
||||||
|
|
||||||
from docker.types import Mount, DeviceRequest
|
from docker.types import Mount, DeviceRequest
|
||||||
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
|
|
||||||
|
|
||||||
from skynet.constants import *
|
from skynet.db import open_new_database
|
||||||
from skynet.brain import run_skynet
|
from skynet.brain import run_skynet
|
||||||
|
from skynet.network import get_random_port
|
||||||
|
from skynet.constants import *
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='session')
|
@pytest.fixture(scope='session')
|
||||||
def postgres_db(dockerctl):
|
def postgres_db(dockerctl):
|
||||||
rpassword = ''.join(
|
with open_new_database() as db_params:
|
||||||
random.choice(string.ascii_lowercase)
|
yield db_params
|
||||||
for i in range(12))
|
|
||||||
password = ''.join(
|
|
||||||
random.choice(string.ascii_lowercase)
|
|
||||||
for i in range(12))
|
|
||||||
|
|
||||||
with dockerctl.run(
|
|
||||||
'postgres',
|
|
||||||
name='skynet-test-postgres',
|
|
||||||
ports={'5432/tcp': None},
|
|
||||||
environment={
|
|
||||||
'POSTGRES_PASSWORD': rpassword
|
|
||||||
}
|
|
||||||
) as containers:
|
|
||||||
container = containers[0]
|
|
||||||
# ip = container.attrs['NetworkSettings']['IPAddress']
|
|
||||||
port = container.ports['5432/tcp'][0]['HostPort']
|
|
||||||
host = f'localhost:{port}'
|
|
||||||
|
|
||||||
for log in container.logs(stream=True):
|
|
||||||
log = log.decode().rstrip()
|
|
||||||
logging.info(log)
|
|
||||||
if ('database system is ready to accept connections' in log or
|
|
||||||
'database system is shut down' in log):
|
|
||||||
break
|
|
||||||
|
|
||||||
# why print the system is ready to accept connections when its not
|
|
||||||
# postgres? wtf
|
|
||||||
time.sleep(1)
|
|
||||||
logging.info('creating skynet db...')
|
|
||||||
|
|
||||||
conn = psycopg2.connect(
|
|
||||||
user='postgres',
|
|
||||||
password=rpassword,
|
|
||||||
host='localhost',
|
|
||||||
port=port
|
|
||||||
)
|
|
||||||
logging.info('connected...')
|
|
||||||
conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
|
|
||||||
with conn.cursor() as cursor:
|
|
||||||
cursor.execute(
|
|
||||||
f'CREATE USER {DB_USER} WITH PASSWORD \'{password}\'')
|
|
||||||
cursor.execute(
|
|
||||||
f'CREATE DATABASE {DB_NAME}')
|
|
||||||
cursor.execute(
|
|
||||||
f'GRANT ALL PRIVILEGES ON DATABASE {DB_NAME} TO {DB_USER}')
|
|
||||||
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
logging.info('done.')
|
|
||||||
yield container, password, host
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def skynet_running(postgres_db):
|
async def skynet_running():
|
||||||
db_container, db_pass, db_host = postgres_db
|
async with run_skynet():
|
||||||
|
|
||||||
async with run_skynet(
|
|
||||||
db_pass=db_pass,
|
|
||||||
db_host=db_host
|
|
||||||
):
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@ -99,11 +40,13 @@ def dgpu_workers(request, dockerctl, skynet_running):
|
||||||
|
|
||||||
cmds = []
|
cmds = []
|
||||||
for i in range(num_containers):
|
for i in range(num_containers):
|
||||||
|
dgpu_addr = f'tcp://127.0.0.1:{get_random_port()}'
|
||||||
cmd = f'''
|
cmd = f'''
|
||||||
pip install -e . && \
|
pip install -e . && \
|
||||||
skynet run dgpu \
|
skynet run dgpu \
|
||||||
--algos=\'{json.dumps(initial_algos)}\' \
|
--algos=\'{json.dumps(initial_algos)}\' \
|
||||||
--uid=dgpu-{i}
|
--uid=dgpu-{i} \
|
||||||
|
--dgpu={dgpu_addr}
|
||||||
'''
|
'''
|
||||||
cmds.append(['bash', '-c', cmd])
|
cmds.append(['bash', '-c', cmd])
|
||||||
|
|
||||||
|
@ -114,16 +57,15 @@ def dgpu_workers(request, dockerctl, skynet_running):
|
||||||
name='skynet-test-runtime-cuda',
|
name='skynet-test-runtime-cuda',
|
||||||
commands=cmds,
|
commands=cmds,
|
||||||
environment={
|
environment={
|
||||||
'HF_TOKEN': os.environ['HF_TOKEN'],
|
|
||||||
'HF_HOME': '/skynet/hf_home'
|
'HF_HOME': '/skynet/hf_home'
|
||||||
},
|
},
|
||||||
network='host',
|
network='host',
|
||||||
mounts=mounts,
|
mounts=mounts,
|
||||||
device_requests=devices,
|
device_requests=devices,
|
||||||
num=num_containers
|
num=num_containers,
|
||||||
) as containers:
|
) as containers:
|
||||||
yield containers
|
yield containers
|
||||||
|
|
||||||
#for i, container in enumerate(containers):
|
for i, container in enumerate(containers):
|
||||||
# logging.info(f'container {i} logs:')
|
logging.info(f'container {i} logs:')
|
||||||
# logging.info(container.logs().decode())
|
logging.info(container.logs().decode())
|
||||||
|
|
|
@ -12,29 +12,26 @@ from functools import partial
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
import pytest
|
import pytest
|
||||||
import trio_asyncio
|
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from google.protobuf.json_format import MessageToDict
|
from google.protobuf.json_format import MessageToDict
|
||||||
|
|
||||||
from skynet.brain import SkynetDGPUComputeError
|
from skynet.brain import SkynetDGPUComputeError
|
||||||
from skynet.constants import *
|
from skynet.network import get_random_port, SessionServer
|
||||||
|
from skynet.protobuf import SkynetRPCResponse
|
||||||
from skynet.frontend import open_skynet_rpc
|
from skynet.frontend import open_skynet_rpc
|
||||||
|
from skynet.constants import *
|
||||||
|
|
||||||
|
|
||||||
async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
|
async def wait_for_dgpus(session, amount: int, timeout: float = 30.0):
|
||||||
gpu_ready = False
|
gpu_ready = False
|
||||||
start_time = time.time()
|
with trio.fail_after(timeout):
|
||||||
current_time = time.time()
|
while not gpu_ready:
|
||||||
while not gpu_ready and (current_time - start_time) < timeout:
|
res = await session.rpc('dgpu_workers')
|
||||||
res = await rpc('dgpu_workers')
|
if res.result['ok'] >= amount:
|
||||||
if res.result['ok'] >= amount:
|
break
|
||||||
break
|
|
||||||
|
|
||||||
await trio.sleep(1)
|
await trio.sleep(1)
|
||||||
current_time = time.time()
|
|
||||||
|
|
||||||
assert (current_time - start_time) < timeout
|
|
||||||
|
|
||||||
|
|
||||||
_images = set()
|
_images = set()
|
||||||
|
@ -48,34 +45,33 @@ async def check_request_img(
|
||||||
):
|
):
|
||||||
global _images
|
global _images
|
||||||
|
|
||||||
async with open_skynet_rpc(
|
with open_skynet_rpc(
|
||||||
uid,
|
uid,
|
||||||
security=True,
|
cert_name='whitelist/testing.cert',
|
||||||
cert_name='whitelist/testing',
|
key_name='testing.key'
|
||||||
key_name='testing'
|
) as session:
|
||||||
) as rpc_call:
|
res = await session.rpc(
|
||||||
res = await rpc_call(
|
'dgpu_call', {
|
||||||
'txt2img', {
|
'method': 'diffuse',
|
||||||
'prompt': 'red old tractor in a sunny wheat field',
|
'params': {
|
||||||
'step': 28,
|
'prompt': 'red old tractor in a sunny wheat field',
|
||||||
'width': width, 'height': height,
|
'step': 28,
|
||||||
'guidance': 7.5,
|
'width': width, 'height': height,
|
||||||
'seed': None,
|
'guidance': 7.5,
|
||||||
'algo': list(ALGOS.keys())[i],
|
'seed': None,
|
||||||
'upscaler': upscaler
|
'algo': list(ALGOS.keys())[i],
|
||||||
})
|
'upscaler': upscaler
|
||||||
|
}
|
||||||
|
},
|
||||||
|
timeout=60
|
||||||
|
)
|
||||||
|
|
||||||
if 'error' in res.result:
|
if 'error' in res.result:
|
||||||
raise SkynetDGPUComputeError(MessageToDict(res.result))
|
raise SkynetDGPUComputeError(MessageToDict(res.result))
|
||||||
|
|
||||||
if upscaler == 'x4':
|
img_raw = res.bin
|
||||||
width *= 4
|
|
||||||
height *= 4
|
|
||||||
|
|
||||||
img_raw = zlib.decompress(bytes.fromhex(res.result['img']))
|
|
||||||
img_sha = sha256(img_raw).hexdigest()
|
img_sha = sha256(img_raw).hexdigest()
|
||||||
img = Image.frombytes(
|
img = Image.open(io.BytesIO(img_raw))
|
||||||
'RGB', (width, height), img_raw)
|
|
||||||
|
|
||||||
if expect_unique and img_sha in _images:
|
if expect_unique and img_sha in _images:
|
||||||
raise ValueError('Duplicated image sha: {img_sha}')
|
raise ValueError('Duplicated image sha: {img_sha}')
|
||||||
|
@ -96,13 +92,12 @@ async def test_dgpu_worker_compute_error(dgpu_workers):
|
||||||
then generate a smaller image to show gpu worker recovery
|
then generate a smaller image to show gpu worker recovery
|
||||||
'''
|
'''
|
||||||
|
|
||||||
async with open_skynet_rpc(
|
with open_skynet_rpc(
|
||||||
'test-ctx',
|
'test-ctx',
|
||||||
security=True,
|
cert_name='whitelist/testing.cert',
|
||||||
cert_name='whitelist/testing',
|
key_name='testing.key'
|
||||||
key_name='testing'
|
) as session:
|
||||||
) as test_rpc:
|
await wait_for_dgpus(session, 1)
|
||||||
await wait_for_dgpus(test_rpc, 1)
|
|
||||||
|
|
||||||
with pytest.raises(SkynetDGPUComputeError) as e:
|
with pytest.raises(SkynetDGPUComputeError) as e:
|
||||||
await check_request_img(0, width=4096, height=4096)
|
await check_request_img(0, width=4096, height=4096)
|
||||||
|
@ -112,20 +107,35 @@ async def test_dgpu_worker_compute_error(dgpu_workers):
|
||||||
await check_request_img(0)
|
await check_request_img(0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
||||||
|
async def test_dgpu_worker(dgpu_workers):
|
||||||
|
'''Generate one image in a single dgpu worker
|
||||||
|
'''
|
||||||
|
|
||||||
|
with open_skynet_rpc(
|
||||||
|
'test-ctx',
|
||||||
|
cert_name='whitelist/testing.cert',
|
||||||
|
key_name='testing.key'
|
||||||
|
) as session:
|
||||||
|
await wait_for_dgpus(session, 1)
|
||||||
|
|
||||||
|
await check_request_img(0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'dgpu_workers', [(1, ['midj', 'stable'])], indirect=True)
|
'dgpu_workers', [(1, ['midj', 'stable'])], indirect=True)
|
||||||
async def test_dgpu_workers(dgpu_workers):
|
async def test_dgpu_worker_two_models(dgpu_workers):
|
||||||
'''Generate two images in a single dgpu worker using
|
'''Generate two images in a single dgpu worker using
|
||||||
two different models.
|
two different models.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
async with open_skynet_rpc(
|
with open_skynet_rpc(
|
||||||
'test-ctx',
|
'test-ctx',
|
||||||
security=True,
|
cert_name='whitelist/testing.cert',
|
||||||
cert_name='whitelist/testing',
|
key_name='testing.key'
|
||||||
key_name='testing'
|
) as session:
|
||||||
) as test_rpc:
|
await wait_for_dgpus(session, 1)
|
||||||
await wait_for_dgpus(test_rpc, 1)
|
|
||||||
|
|
||||||
await check_request_img(0)
|
await check_request_img(0)
|
||||||
await check_request_img(1)
|
await check_request_img(1)
|
||||||
|
@ -138,14 +148,12 @@ async def test_dgpu_worker_upscale(dgpu_workers):
|
||||||
two different models.
|
two different models.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
async with open_skynet_rpc(
|
with open_skynet_rpc(
|
||||||
'test-ctx',
|
'test-ctx',
|
||||||
security=True,
|
cert_name='whitelist/testing.cert',
|
||||||
cert_name='whitelist/testing',
|
key_name='testing.key'
|
||||||
key_name='testing'
|
) as session:
|
||||||
) as test_rpc:
|
await wait_for_dgpus(session, 1)
|
||||||
await wait_for_dgpus(test_rpc, 1)
|
|
||||||
logging.error('UPSCALE')
|
|
||||||
|
|
||||||
img = await check_request_img(0, upscaler='x4')
|
img = await check_request_img(0, upscaler='x4')
|
||||||
|
|
||||||
|
@ -157,13 +165,12 @@ async def test_dgpu_worker_upscale(dgpu_workers):
|
||||||
async def test_dgpu_workers_two(dgpu_workers):
|
async def test_dgpu_workers_two(dgpu_workers):
|
||||||
'''Generate two images in two separate dgpu workers
|
'''Generate two images in two separate dgpu workers
|
||||||
'''
|
'''
|
||||||
async with open_skynet_rpc(
|
with open_skynet_rpc(
|
||||||
'test-ctx',
|
'test-ctx',
|
||||||
security=True,
|
cert_name='whitelist/testing.cert',
|
||||||
cert_name='whitelist/testing',
|
key_name='testing.key'
|
||||||
key_name='testing'
|
) as session:
|
||||||
) as test_rpc:
|
await wait_for_dgpus(session, 2, timeout=60)
|
||||||
await wait_for_dgpus(test_rpc, 2)
|
|
||||||
|
|
||||||
async with trio.open_nursery() as n:
|
async with trio.open_nursery() as n:
|
||||||
n.start_soon(check_request_img, 0)
|
n.start_soon(check_request_img, 0)
|
||||||
|
@ -175,13 +182,12 @@ async def test_dgpu_workers_two(dgpu_workers):
|
||||||
async def test_dgpu_worker_algo_swap(dgpu_workers):
|
async def test_dgpu_worker_algo_swap(dgpu_workers):
|
||||||
'''Generate an image using a non default model
|
'''Generate an image using a non default model
|
||||||
'''
|
'''
|
||||||
async with open_skynet_rpc(
|
with open_skynet_rpc(
|
||||||
'test-ctx',
|
'test-ctx',
|
||||||
security=True,
|
cert_name='whitelist/testing.cert',
|
||||||
cert_name='whitelist/testing',
|
key_name='testing.key'
|
||||||
key_name='testing'
|
) as session:
|
||||||
) as test_rpc:
|
await wait_for_dgpus(session, 1)
|
||||||
await wait_for_dgpus(test_rpc, 1)
|
|
||||||
await check_request_img(5)
|
await check_request_img(5)
|
||||||
|
|
||||||
|
|
||||||
|
@ -191,33 +197,32 @@ async def test_dgpu_rotation_next_worker(dgpu_workers):
|
||||||
'''Connect three dgpu workers, disconnect and check next_worker
|
'''Connect three dgpu workers, disconnect and check next_worker
|
||||||
rotation happens correctly
|
rotation happens correctly
|
||||||
'''
|
'''
|
||||||
async with open_skynet_rpc(
|
with open_skynet_rpc(
|
||||||
'test-ctx',
|
'test-ctx',
|
||||||
security=True,
|
cert_name='whitelist/testing.cert',
|
||||||
cert_name='whitelist/testing',
|
key_name='testing.key'
|
||||||
key_name='testing'
|
) as session:
|
||||||
) as test_rpc:
|
await wait_for_dgpus(session, 3)
|
||||||
await wait_for_dgpus(test_rpc, 3)
|
|
||||||
|
|
||||||
res = await test_rpc('dgpu_next')
|
res = await session.rpc('dgpu_next')
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
assert res.result['ok'] == 0
|
assert res.result['ok'] == 0
|
||||||
|
|
||||||
await check_request_img(0)
|
await check_request_img(0)
|
||||||
|
|
||||||
res = await test_rpc('dgpu_next')
|
res = await session.rpc('dgpu_next')
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
assert res.result['ok'] == 1
|
assert res.result['ok'] == 1
|
||||||
|
|
||||||
await check_request_img(0)
|
await check_request_img(0)
|
||||||
|
|
||||||
res = await test_rpc('dgpu_next')
|
res = await session.rpc('dgpu_next')
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
assert res.result['ok'] == 2
|
assert res.result['ok'] == 2
|
||||||
|
|
||||||
await check_request_img(0)
|
await check_request_img(0)
|
||||||
|
|
||||||
res = await test_rpc('dgpu_next')
|
res = await session.rpc('dgpu_next')
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
assert res.result['ok'] == 0
|
assert res.result['ok'] == 0
|
||||||
|
|
||||||
|
@ -228,13 +233,12 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
|
||||||
'''Connect three dgpu workers, disconnect the first one and check
|
'''Connect three dgpu workers, disconnect the first one and check
|
||||||
next_worker rotation happens correctly
|
next_worker rotation happens correctly
|
||||||
'''
|
'''
|
||||||
async with open_skynet_rpc(
|
with open_skynet_rpc(
|
||||||
'test-ctx',
|
'test-ctx',
|
||||||
security=True,
|
cert_name='whitelist/testing.cert',
|
||||||
cert_name='whitelist/testing',
|
key_name='testing.key'
|
||||||
key_name='testing'
|
) as session:
|
||||||
) as test_rpc:
|
await wait_for_dgpus(session, 3)
|
||||||
await wait_for_dgpus(test_rpc, 3)
|
|
||||||
|
|
||||||
await trio.sleep(3)
|
await trio.sleep(3)
|
||||||
|
|
||||||
|
@ -245,7 +249,7 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
|
||||||
|
|
||||||
dgpu_workers[0].wait()
|
dgpu_workers[0].wait()
|
||||||
|
|
||||||
res = await test_rpc('dgpu_workers')
|
res = await session.rpc('dgpu_workers')
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
assert res.result['ok'] == 2
|
assert res.result['ok'] == 2
|
||||||
|
|
||||||
|
@ -258,26 +262,43 @@ async def test_dgpu_no_ack_node_disconnect(skynet_running):
|
||||||
'''Mock a node that connects, gets a request but fails to
|
'''Mock a node that connects, gets a request but fails to
|
||||||
acknowledge it, then check skynet correctly drops the node
|
acknowledge it, then check skynet correctly drops the node
|
||||||
'''
|
'''
|
||||||
async with open_skynet_rpc(
|
|
||||||
'test-ctx',
|
|
||||||
security=True,
|
|
||||||
cert_name='whitelist/testing',
|
|
||||||
key_name='testing'
|
|
||||||
) as rpc_call:
|
|
||||||
|
|
||||||
res = await rpc_call('dgpu_online')
|
async def mock_rpc(req, ctx):
|
||||||
assert 'ok' in res.result
|
resp = SkynetRPCResponse()
|
||||||
|
resp.result.update({'error': 'can\'t do it mate'})
|
||||||
|
return resp
|
||||||
|
|
||||||
await wait_for_dgpus(rpc_call, 1)
|
dgpu_addr = f'tcp://127.0.0.1:{get_random_port()}'
|
||||||
|
mock_server = SessionServer(
|
||||||
|
dgpu_addr,
|
||||||
|
mock_rpc,
|
||||||
|
cert_name='whitelist/testing.cert',
|
||||||
|
key_name='testing.key'
|
||||||
|
)
|
||||||
|
|
||||||
with pytest.raises(SkynetDGPUComputeError) as e:
|
async with mock_server.open():
|
||||||
await check_request_img(0)
|
with open_skynet_rpc(
|
||||||
|
'test-ctx',
|
||||||
|
cert_name='whitelist/testing.cert',
|
||||||
|
key_name='testing.key'
|
||||||
|
) as session:
|
||||||
|
|
||||||
assert 'dgpu failed to acknowledge request' in str(e)
|
res = await session.rpc('dgpu_online', {
|
||||||
|
'dgpu_addr': dgpu_addr,
|
||||||
|
'cert': 'whitelist/testing.cert'
|
||||||
|
})
|
||||||
|
assert 'ok' in res.result
|
||||||
|
|
||||||
res = await rpc_call('dgpu_workers')
|
await wait_for_dgpus(session, 1)
|
||||||
assert 'ok' in res.result
|
|
||||||
assert res.result['ok'] == 0
|
with pytest.raises(SkynetDGPUComputeError) as e:
|
||||||
|
await check_request_img(0)
|
||||||
|
|
||||||
|
assert 'can\'t do it mate' in str(e.value)
|
||||||
|
|
||||||
|
res = await session.rpc('dgpu_workers')
|
||||||
|
assert 'ok' in res.result
|
||||||
|
assert res.result['ok'] == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -286,13 +307,12 @@ async def test_dgpu_timeout_while_processing(dgpu_workers):
|
||||||
'''Stop node while processing request to cause timeout and
|
'''Stop node while processing request to cause timeout and
|
||||||
then check skynet correctly drops the node.
|
then check skynet correctly drops the node.
|
||||||
'''
|
'''
|
||||||
async with open_skynet_rpc(
|
with open_skynet_rpc(
|
||||||
'test-ctx',
|
'test-ctx',
|
||||||
security=True,
|
cert_name='whitelist/testing.cert',
|
||||||
cert_name='whitelist/testing',
|
key_name='testing.key'
|
||||||
key_name='testing'
|
) as session:
|
||||||
) as test_rpc:
|
await wait_for_dgpus(session, 1)
|
||||||
await wait_for_dgpus(test_rpc, 1)
|
|
||||||
|
|
||||||
async def check_request_img_raises():
|
async def check_request_img_raises():
|
||||||
with pytest.raises(SkynetDGPUComputeError) as e:
|
with pytest.raises(SkynetDGPUComputeError) as e:
|
||||||
|
@ -308,72 +328,62 @@ async def test_dgpu_timeout_while_processing(dgpu_workers):
|
||||||
assert ec == 0
|
assert ec == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
|
||||||
async def test_dgpu_heartbeat(dgpu_workers):
|
|
||||||
'''
|
|
||||||
'''
|
|
||||||
async with open_skynet_rpc(
|
|
||||||
'test-ctx',
|
|
||||||
security=True,
|
|
||||||
cert_name='whitelist/testing',
|
|
||||||
key_name='testing'
|
|
||||||
) as test_rpc:
|
|
||||||
await wait_for_dgpus(test_rpc, 1)
|
|
||||||
await trio.sleep(120)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
||||||
async def test_dgpu_img2img(dgpu_workers):
|
async def test_dgpu_img2img(dgpu_workers):
|
||||||
|
|
||||||
async with open_skynet_rpc(
|
with open_skynet_rpc(
|
||||||
'1',
|
'test-ctx',
|
||||||
security=True,
|
cert_name='whitelist/testing.cert',
|
||||||
cert_name='whitelist/testing',
|
key_name='testing.key'
|
||||||
key_name='testing'
|
) as session:
|
||||||
) as rpc_call:
|
await wait_for_dgpus(session, 1)
|
||||||
await wait_for_dgpus(rpc_call, 1)
|
|
||||||
|
|
||||||
|
await trio.sleep(2)
|
||||||
|
|
||||||
res = await rpc_call(
|
res = await session.rpc(
|
||||||
'txt2img', {
|
'dgpu_call', {
|
||||||
'prompt': 'red old tractor in a sunny wheat field',
|
'method': 'diffuse',
|
||||||
'step': 28,
|
'params': {
|
||||||
'width': 512, 'height': 512,
|
'prompt': 'red old tractor in a sunny wheat field',
|
||||||
'guidance': 7.5,
|
'step': 28,
|
||||||
'seed': None,
|
'width': 512, 'height': 512,
|
||||||
'algo': list(ALGOS.keys())[0],
|
'guidance': 7.5,
|
||||||
'upscaler': None
|
'seed': None,
|
||||||
})
|
'algo': list(ALGOS.keys())[0],
|
||||||
|
'upscaler': None
|
||||||
|
}
|
||||||
|
},
|
||||||
|
timeout=60
|
||||||
|
)
|
||||||
|
|
||||||
if 'error' in res.result:
|
if 'error' in res.result:
|
||||||
raise SkynetDGPUComputeError(MessageToDict(res.result))
|
raise SkynetDGPUComputeError(MessageToDict(res.result))
|
||||||
|
|
||||||
img_raw = res.result['img']
|
img_raw = res.bin
|
||||||
img = zlib.decompress(bytes.fromhex(img_raw))
|
img = Image.open(io.BytesIO(img_raw))
|
||||||
logging.info(img[:10])
|
|
||||||
img = Image.open(io.BytesIO(img))
|
|
||||||
|
|
||||||
img.save('txt2img.png')
|
img.save('txt2img.png')
|
||||||
|
|
||||||
res = await rpc_call(
|
res = await session.rpc(
|
||||||
'img2img', {
|
'dgpu_call', {
|
||||||
'prompt': 'red sports car in a sunny wheat field',
|
'method': 'diffuse',
|
||||||
'step': 28,
|
'params': {
|
||||||
'img': img_raw,
|
'prompt': 'red ferrari in a sunny wheat field',
|
||||||
'guidance': 12,
|
'step': 28,
|
||||||
'seed': None,
|
'guidance': 8,
|
||||||
'algo': list(ALGOS.keys())[0],
|
'strength': 0.7,
|
||||||
'upscaler': 'x4'
|
'seed': None,
|
||||||
})
|
'algo': list(ALGOS.keys())[0],
|
||||||
|
'upscaler': 'x4'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
binext=img_raw,
|
||||||
|
timeout=60
|
||||||
|
)
|
||||||
|
|
||||||
if 'error' in res.result:
|
if 'error' in res.result:
|
||||||
raise SkynetDGPUComputeError(MessageToDict(res.result))
|
raise SkynetDGPUComputeError(MessageToDict(res.result))
|
||||||
|
|
||||||
img_raw = res.result['img']
|
img_raw = res.bin
|
||||||
img = zlib.decompress(bytes.fromhex(img_raw))
|
img = Image.open(io.BytesIO(img_raw))
|
||||||
logging.info(img[:10])
|
|
||||||
img = Image.open(io.BytesIO(img))
|
|
||||||
|
|
||||||
img.save('img2img.png')
|
img.save('img2img.png')
|
||||||
|
|
|
@ -9,6 +9,7 @@ import trio_asyncio
|
||||||
|
|
||||||
from skynet.brain import run_skynet
|
from skynet.brain import run_skynet
|
||||||
from skynet.structs import *
|
from skynet.structs import *
|
||||||
|
from skynet.network import SessionServer
|
||||||
from skynet.frontend import open_skynet_rpc
|
from skynet.frontend import open_skynet_rpc
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,53 +19,68 @@ async def test_skynet(skynet_running):
|
||||||
|
|
||||||
async def test_skynet_attempt_insecure(skynet_running):
|
async def test_skynet_attempt_insecure(skynet_running):
|
||||||
with pytest.raises(pynng.exceptions.NNGException) as e:
|
with pytest.raises(pynng.exceptions.NNGException) as e:
|
||||||
async with open_skynet_rpc('bad-actor'):
|
with open_skynet_rpc('bad-actor') as session:
|
||||||
...
|
with trio.fail_after(5):
|
||||||
|
await session.rpc('skynet_shutdown')
|
||||||
assert str(e.value) == 'Connection shutdown'
|
|
||||||
|
|
||||||
|
|
||||||
async def test_skynet_dgpu_connection_simple(skynet_running):
|
async def test_skynet_dgpu_connection_simple(skynet_running):
|
||||||
async with open_skynet_rpc(
|
|
||||||
|
async def rpc_handler(req, ctx):
|
||||||
|
...
|
||||||
|
|
||||||
|
fake_dgpu_addr = 'tcp://127.0.0.1:41001'
|
||||||
|
rpc_server = SessionServer(
|
||||||
|
fake_dgpu_addr,
|
||||||
|
rpc_handler,
|
||||||
|
cert_name='whitelist/testing.cert',
|
||||||
|
key_name='testing.key'
|
||||||
|
)
|
||||||
|
|
||||||
|
with open_skynet_rpc(
|
||||||
'dgpu-0',
|
'dgpu-0',
|
||||||
security=True,
|
cert_name='whitelist/testing.cert',
|
||||||
cert_name='whitelist/testing',
|
key_name='testing.key'
|
||||||
key_name='testing'
|
) as session:
|
||||||
) as rpc_call:
|
|
||||||
# check 0 nodes are connected
|
# check 0 nodes are connected
|
||||||
res = await rpc_call('dgpu_workers')
|
res = await session.rpc('dgpu_workers')
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result.keys()
|
||||||
assert res.result['ok'] == 0
|
assert res.result['ok'] == 0
|
||||||
|
|
||||||
# check next worker is None
|
# check next worker is None
|
||||||
res = await rpc_call('dgpu_next')
|
res = await session.rpc('dgpu_next')
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result.keys()
|
||||||
assert res.result['ok'] == None
|
assert res.result['ok'] == None
|
||||||
|
|
||||||
# connect 1 dgpu
|
async with rpc_server.open() as rpc_server:
|
||||||
res = await rpc_call('dgpu_online')
|
# connect 1 dgpu
|
||||||
assert 'ok' in res.result
|
res = await session.rpc(
|
||||||
|
'dgpu_online', {
|
||||||
|
'dgpu_addr': fake_dgpu_addr,
|
||||||
|
'cert': 'whitelist/testing.cert'
|
||||||
|
})
|
||||||
|
assert 'ok' in res.result.keys()
|
||||||
|
|
||||||
# check 1 node is connected
|
# check 1 node is connected
|
||||||
res = await rpc_call('dgpu_workers')
|
res = await session.rpc('dgpu_workers')
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result.keys()
|
||||||
assert res.result['ok'] == 1
|
assert res.result['ok'] == 1
|
||||||
|
|
||||||
# check next worker is 0
|
# check next worker is 0
|
||||||
res = await rpc_call('dgpu_next')
|
res = await session.rpc('dgpu_next')
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result.keys()
|
||||||
assert res.result['ok'] == 0
|
assert res.result['ok'] == 0
|
||||||
|
|
||||||
# disconnect 1 dgpu
|
# disconnect 1 dgpu
|
||||||
res = await rpc_call('dgpu_offline')
|
res = await session.rpc('dgpu_offline')
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result.keys()
|
||||||
|
|
||||||
# check 0 nodes are connected
|
# check 0 nodes are connected
|
||||||
res = await rpc_call('dgpu_workers')
|
res = await session.rpc('dgpu_workers')
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result.keys()
|
||||||
assert res.result['ok'] == 0
|
assert res.result['ok'] == 0
|
||||||
|
|
||||||
# check next worker is None
|
# check next worker is None
|
||||||
res = await rpc_call('dgpu_next')
|
res = await session.rpc('dgpu_next')
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result.keys()
|
||||||
assert res.result['ok'] == None
|
assert res.result['ok'] == None
|
||||||
|
|
|
@ -0,0 +1,28 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import trio
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from skynet.db import open_new_database
|
||||||
|
from skynet.brain import run_skynet
|
||||||
|
from skynet.config import load_skynet_ini
|
||||||
|
from skynet.frontend.telegram import run_skynet_telegram
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
'''You will need a telegram bot token configured on skynet.ini for this
|
||||||
|
'''
|
||||||
|
with open_new_database() as db_params:
|
||||||
|
db_container, db_pass, db_host = db_params
|
||||||
|
config = load_skynet_ini()
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
await run_skynet_telegram(
|
||||||
|
'telegram-test',
|
||||||
|
config['skynet.telegram-test']['token'],
|
||||||
|
db_host=db_host,
|
||||||
|
db_pass=db_pass
|
||||||
|
)
|
||||||
|
|
||||||
|
trio.run(main)
|
Loading…
Reference in New Issue