mirror of https://github.com/skygpu/skynet.git
commit
79901c85ca
|
@ -1,3 +1,4 @@
|
|||
skynet.ini
|
||||
.python-version
|
||||
hf_home
|
||||
outputs
|
||||
|
|
|
@ -32,3 +32,4 @@ env HF_HOME /hf_home
|
|||
copy scripts scripts
|
||||
copy tests tests
|
||||
|
||||
expose 40000-45000
|
||||
|
|
665
LICENSE
665
LICENSE
|
@ -1,11 +1,662 @@
|
|||
A menos que sea especificamente indicado en el cabezal del archivo, se reservan
|
||||
todos los derechos sobre este codigo por parte de:
|
||||
GNU AFFERO GENERAL PUBLIC LICENSE
|
||||
Version 3, 19 November 2007
|
||||
|
||||
Guillermo Rodriguez, guillermor@fing.edu.uy
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
of this license document, but changing it is not allowed.
|
||||
|
||||
ENGLISH LICENSE:
|
||||
Preamble
|
||||
|
||||
Unless specifically indicated in the file header, all rights to this code are
|
||||
reserved by:
|
||||
The GNU Affero General Public License is a free, copyleft license for
|
||||
software and other kinds of works, specifically designed to ensure
|
||||
cooperation with the community in the case of network server software.
|
||||
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
our General Public Licenses are intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users.
|
||||
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
have the freedom to distribute copies of free software (and charge for
|
||||
them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
Developers that use our General Public Licenses protect your rights
|
||||
with two steps: (1) assert copyright on the software, and (2) offer
|
||||
you this License which gives you legal permission to copy, distribute
|
||||
and/or modify the software.
|
||||
|
||||
A secondary benefit of defending all users' freedom is that
|
||||
improvements made in alternate versions of the program, if they
|
||||
receive widespread use, become available for other developers to
|
||||
incorporate. Many developers of free software are heartened and
|
||||
encouraged by the resulting cooperation. However, in the case of
|
||||
software used on network servers, this result may fail to come about.
|
||||
The GNU General Public License permits making a modified version and
|
||||
letting the public access it on a server without ever releasing its
|
||||
source code to the public.
|
||||
|
||||
The GNU Affero General Public License is designed specifically to
|
||||
ensure that, in such cases, the modified source code becomes available
|
||||
to the community. It requires the operator of a network server to
|
||||
provide the source code of the modified version running there to the
|
||||
users of that server. Therefore, public use of a modified version, on
|
||||
a publicly accessible server, gives the public access to the source
|
||||
code of the modified version.
|
||||
|
||||
An older license, called the Affero General Public License and
|
||||
published by Affero, was designed to accomplish similar goals. This is
|
||||
a different license, not a version of the Affero GPL, but Affero has
|
||||
released a new version of the Affero GPL which permits relicensing under
|
||||
this license.
|
||||
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
|
||||
TERMS AND CONDITIONS
|
||||
|
||||
0. Definitions.
|
||||
|
||||
"This License" refers to version 3 of the GNU Affero General Public License.
|
||||
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
|
||||
"The Program" refers to any copyrightable work licensed under this
|
||||
License. Each licensee is addressed as "you". "Licensees" and
|
||||
"recipients" may be individuals or organizations.
|
||||
|
||||
To "modify" a work means to copy from or adapt all or part of the work
|
||||
in a fashion requiring copyright permission, other than the making of an
|
||||
exact copy. The resulting work is called a "modified version" of the
|
||||
earlier work or a work "based on" the earlier work.
|
||||
|
||||
A "covered work" means either the unmodified Program or a work based
|
||||
on the Program.
|
||||
|
||||
To "propagate" a work means to do anything with it that, without
|
||||
permission, would make you directly or secondarily liable for
|
||||
infringement under applicable copyright law, except executing it on a
|
||||
computer or modifying a private copy. Propagation includes copying,
|
||||
distribution (with or without modification), making available to the
|
||||
public, and in some countries other activities as well.
|
||||
|
||||
To "convey" a work means any kind of propagation that enables other
|
||||
parties to make or receive copies. Mere interaction with a user through
|
||||
a computer network, with no transfer of a copy, is not conveying.
|
||||
|
||||
An interactive user interface displays "Appropriate Legal Notices"
|
||||
to the extent that it includes a convenient and prominently visible
|
||||
feature that (1) displays an appropriate copyright notice, and (2)
|
||||
tells the user that there is no warranty for the work (except to the
|
||||
extent that warranties are provided), that licensees may convey the
|
||||
work under this License, and how to view a copy of this License. If
|
||||
the interface presents a list of user commands or options, such as a
|
||||
menu, a prominent item in the list meets this criterion.
|
||||
|
||||
1. Source Code.
|
||||
|
||||
The "source code" for a work means the preferred form of the work
|
||||
for making modifications to it. "Object code" means any non-source
|
||||
form of a work.
|
||||
|
||||
A "Standard Interface" means an interface that either is an official
|
||||
standard defined by a recognized standards body, or, in the case of
|
||||
interfaces specified for a particular programming language, one that
|
||||
is widely used among developers working in that language.
|
||||
|
||||
The "System Libraries" of an executable work include anything, other
|
||||
than the work as a whole, that (a) is included in the normal form of
|
||||
packaging a Major Component, but which is not part of that Major
|
||||
Component, and (b) serves only to enable use of the work with that
|
||||
Major Component, or to implement a Standard Interface for which an
|
||||
implementation is available to the public in source code form. A
|
||||
"Major Component", in this context, means a major essential component
|
||||
(kernel, window system, and so on) of the specific operating system
|
||||
(if any) on which the executable work runs, or a compiler used to
|
||||
produce the work, or an object code interpreter used to run it.
|
||||
|
||||
The "Corresponding Source" for a work in object code form means all
|
||||
the source code needed to generate, install, and (for an executable
|
||||
work) run the object code and to modify the work, including scripts to
|
||||
control those activities. However, it does not include the work's
|
||||
System Libraries, or general-purpose tools or generally available free
|
||||
programs which are used unmodified in performing those activities but
|
||||
which are not part of the work. For example, Corresponding Source
|
||||
includes interface definition files associated with source files for
|
||||
the work, and the source code for shared libraries and dynamically
|
||||
linked subprograms that the work is specifically designed to require,
|
||||
such as by intimate data communication or control flow between those
|
||||
subprograms and other parts of the work.
|
||||
|
||||
The Corresponding Source need not include anything that users
|
||||
can regenerate automatically from other parts of the Corresponding
|
||||
Source.
|
||||
|
||||
The Corresponding Source for a work in source code form is that
|
||||
same work.
|
||||
|
||||
2. Basic Permissions.
|
||||
|
||||
All rights granted under this License are granted for the term of
|
||||
copyright on the Program, and are irrevocable provided the stated
|
||||
conditions are met. This License explicitly affirms your unlimited
|
||||
permission to run the unmodified Program. The output from running a
|
||||
covered work is covered by this License only if the output, given its
|
||||
content, constitutes a covered work. This License acknowledges your
|
||||
rights of fair use or other equivalent, as provided by copyright law.
|
||||
|
||||
You may make, run and propagate covered works that you do not
|
||||
convey, without conditions so long as your license otherwise remains
|
||||
in force. You may convey covered works to others for the sole purpose
|
||||
of having them make modifications exclusively for you, or provide you
|
||||
with facilities for running those works, provided that you comply with
|
||||
the terms of this License in conveying all material for which you do
|
||||
not control copyright. Those thus making or running the covered works
|
||||
for you must do so exclusively on your behalf, under your direction
|
||||
and control, on terms that prohibit them from making any copies of
|
||||
your copyrighted material outside their relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under
|
||||
the conditions stated below. Sublicensing is not allowed; section 10
|
||||
makes it unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article
|
||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||
similar laws prohibiting or restricting circumvention of such
|
||||
measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention
|
||||
is effected by exercising rights under this License with respect to
|
||||
the covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's
|
||||
users, your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice;
|
||||
keep intact all notices stating that this License and any
|
||||
non-permissive terms added in accord with section 7 apply to the code;
|
||||
keep intact all notices of the absence of any warranty; and give all
|
||||
recipients a copy of this License along with the Program.
|
||||
|
||||
You may charge any price or no price for each copy that you convey,
|
||||
and you may offer support or warranty protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the
|
||||
terms of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified
|
||||
it, and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is
|
||||
released under this License and any conditions added under section
|
||||
7. This requirement modifies the requirement in section 4 to
|
||||
"keep intact all notices".
|
||||
|
||||
c) You must license the entire work, as a whole, under this
|
||||
License to anyone who comes into possession of a copy. This
|
||||
License will therefore apply, along with any applicable section 7
|
||||
additional terms, to the whole of the work, and all its parts,
|
||||
regardless of how they are packaged. This License gives no
|
||||
permission to license the work in any other way, but it does not
|
||||
invalidate such permission if you have separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your
|
||||
work need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work,
|
||||
and which are not combined with it such as to form a larger program,
|
||||
in or on a volume of a storage or distribution medium, is called an
|
||||
"aggregate" if the compilation and its resulting copyright are not
|
||||
used to limit the access or legal rights of the compilation's users
|
||||
beyond what the individual works permit. Inclusion of a covered work
|
||||
in an aggregate does not cause this License to apply to the other
|
||||
parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms
|
||||
of sections 4 and 5, provided that you also convey the
|
||||
machine-readable Corresponding Source under the terms of this License,
|
||||
in one of these ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium
|
||||
customarily used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a
|
||||
written offer, valid for at least three years and valid for as
|
||||
long as you offer spare parts or customer support for that product
|
||||
model, to give anyone who possesses the object code either (1) a
|
||||
copy of the Corresponding Source for all the software in the
|
||||
product that is covered by this License, on a durable physical
|
||||
medium customarily used for software interchange, for a price no
|
||||
more than your reasonable cost of physically performing this
|
||||
conveying of source, or (2) access to copy the
|
||||
Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This
|
||||
alternative is allowed only occasionally and noncommercially, and
|
||||
only if you received the object code with such an offer, in accord
|
||||
with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated
|
||||
place (gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to
|
||||
copy the object code is a network server, the Corresponding Source
|
||||
may be on a different server (operated by you or a third party)
|
||||
that supports equivalent copying facilities, provided you maintain
|
||||
clear directions next to the object code saying where to find the
|
||||
Corresponding Source. Regardless of what server hosts the
|
||||
Corresponding Source, you remain obligated to ensure that it is
|
||||
available for as long as needed to satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided
|
||||
you inform other peers where the object code and Corresponding
|
||||
Source of the work are being offered to the general public at no
|
||||
charge under subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be
|
||||
included in conveying the object code work.
|
||||
|
||||
A "User Product" is either (1) a "consumer product", which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, "normally used" refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
"Installation Information" for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as
|
||||
part of a transaction in which the right of possession and use of the
|
||||
User Product is transferred to the recipient in perpetuity or for a
|
||||
fixed term (regardless of how the transaction is characterized), the
|
||||
Corresponding Source conveyed under this section must be accompanied
|
||||
by the Installation Information. But this requirement does not apply
|
||||
if neither you nor any third party retains the ability to install
|
||||
modified object code on the User Product (for example, the work has
|
||||
been installed in ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access to a
|
||||
network may be denied when the modification itself materially and
|
||||
adversely affects the operation of the network or violates the rules and
|
||||
protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided,
|
||||
in accord with this section must be in a format that is publicly
|
||||
documented (and with an implementation available to the public in
|
||||
source code form), and must require no special password or key for
|
||||
unpacking, reading or copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Remote Network Interaction; Use with the GNU General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, if you modify the
|
||||
Program, your modified version must prominently offer all users
|
||||
interacting with it remotely through a computer network (if your version
|
||||
supports such interaction) an opportunity to receive the Corresponding
|
||||
Source of your version by providing access to the Corresponding Source
|
||||
from a network server at no charge, through some standard or customary
|
||||
means of facilitating copying of software. This Corresponding Source
|
||||
shall include the Corresponding Source for any work covered by version 3
|
||||
of the GNU General Public License that is incorporated pursuant to the
|
||||
following paragraph.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the work with which it is combined will remain governed by version
|
||||
3 of the GNU General Public License.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU Affero General Public License from time to time. Such new versions
|
||||
will be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU Affero General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU Affero General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU Affero General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
How to Apply These Terms to Your New Programs
|
||||
|
||||
If you develop a new program, and you want it to be of the greatest
|
||||
possible use to the public, the best way to achieve this is to make it
|
||||
free software which everyone can redistribute and change under these terms.
|
||||
|
||||
To do so, attach the following notices to the program. It is safest
|
||||
to attach them to the start of each source file to most effectively
|
||||
state the exclusion of warranty; and each file should have at least
|
||||
the "copyright" line and a pointer to where the full notice is found.
|
||||
|
||||
<one line to give the program's name and a brief idea of what it does.>
|
||||
Copyright (C) <year> <name of author>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published
|
||||
by the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If your software can interact with users remotely through a computer
|
||||
network, you should also make sure that it provides a way for users to
|
||||
get its source. For example, if your program is a web application, its
|
||||
interface could display a "Source" link that leads users to an archive
|
||||
of the code. There are many ways you could offer source, and different
|
||||
solutions will be better for different programs; see section 13 for the
|
||||
specific requirements.
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU AGPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
||||
|
||||
Guillermo Rodriguez, guillermor@.edu.uy
|
||||
|
|
|
@ -3,4 +3,4 @@ pytest
|
|||
pytest-trio
|
||||
psycopg2-binary
|
||||
|
||||
git+https://github.com/guilledk/pytest-dockerctl.git@host_network#egg=pytest-dockerctl
|
||||
git+https://github.com/guilledk/pytest-dockerctl.git@multi_names#egg=pytest-dockerctl
|
||||
|
|
|
@ -9,3 +9,5 @@ protobuf
|
|||
pyOpenSSL
|
||||
trio_asyncio
|
||||
pyTelegramBotAPI
|
||||
|
||||
git+https://github.com/goodboy/tractor.git@master#egg=tractor
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
[skynet]
|
||||
certs_dir = certs
|
||||
|
||||
[skynet.dgpu]
|
||||
hf_home = hf_home
|
||||
hf_token = hf_XxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXx
|
||||
|
||||
[skynet.telegram]
|
||||
token = XXXXXXXXXX:xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
|
||||
[skynet.telegram-test]
|
||||
token = XXXXXXXXXX:xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
504
skynet/brain.py
504
skynet/brain.py
|
@ -1,35 +1,24 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
import zlib
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from uuid import UUID
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
from contextlib import asynccontextmanager as acm
|
||||
from collections import OrderedDict
|
||||
|
||||
import trio
|
||||
import pynng
|
||||
import trio_asyncio
|
||||
|
||||
from pynng import TLSConfig
|
||||
from OpenSSL.crypto import (
|
||||
load_privatekey,
|
||||
load_certificate,
|
||||
FILETYPE_PEM
|
||||
)
|
||||
from pynng import Context
|
||||
|
||||
from .db import *
|
||||
from .utils import time_ms
|
||||
from .network import *
|
||||
from .protobuf import *
|
||||
from .constants import *
|
||||
|
||||
from .protobuf import *
|
||||
|
||||
|
||||
class SkynetRPCBadRequest(BaseException):
|
||||
...
|
||||
|
||||
class SkynetDGPUOffline(BaseException):
|
||||
...
|
||||
|
||||
|
@ -44,39 +33,71 @@ class SkynetShutdownRequested(BaseException):
|
|||
|
||||
|
||||
@acm
|
||||
async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||
async def run_skynet(
|
||||
rpc_address: str = DEFAULT_RPC_ADDR
|
||||
):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.info('skynet is starting')
|
||||
|
||||
nodes = OrderedDict()
|
||||
wip_reqs = {}
|
||||
fin_reqs = {}
|
||||
heartbeats = {}
|
||||
next_worker: Optional[int] = None
|
||||
security = len(tls_whitelist) > 0
|
||||
|
||||
def connect_node(uid):
|
||||
def connect_node(req: SkynetRPCRequest):
|
||||
nonlocal next_worker
|
||||
nodes[uid] = {
|
||||
'task': None
|
||||
}
|
||||
logging.info(f'dgpu online: {uid}')
|
||||
|
||||
if not next_worker:
|
||||
next_worker = 0
|
||||
node_params = MessageToDict(req.params)
|
||||
logging.info(f'got node params {node_params}')
|
||||
|
||||
if 'dgpu_addr' not in node_params:
|
||||
raise SkynetRPCBadRequest(
|
||||
f'DGPU connection params don\'t include dgpu addr')
|
||||
|
||||
session = SessionClient(
|
||||
node_params['dgpu_addr'],
|
||||
'skynet',
|
||||
cert_name='brain.cert',
|
||||
key_name='brain.key',
|
||||
ca_name=node_params['cert']
|
||||
)
|
||||
try:
|
||||
session.connect()
|
||||
|
||||
node = {
|
||||
'task': None,
|
||||
'session': session
|
||||
}
|
||||
node.update(node_params)
|
||||
|
||||
nodes[req.uid] = node
|
||||
logging.info(f'DGPU node online: {req.uid}')
|
||||
|
||||
if not next_worker:
|
||||
next_worker = 0
|
||||
|
||||
except pynng.exceptions.ConnectionRefused:
|
||||
logging.warning(f'error while dialing dgpu node... dropping...')
|
||||
raise SkynetDGPUOffline('Connection to dgpu node addr failed.')
|
||||
|
||||
def disconnect_node(uid):
|
||||
nonlocal next_worker
|
||||
if uid not in nodes:
|
||||
logging.warning(f'Attempt to disconnect unknown node {uid}')
|
||||
return
|
||||
|
||||
i = list(nodes.keys()).index(uid)
|
||||
nodes[uid]['session'].disconnect()
|
||||
del nodes[uid]
|
||||
|
||||
if i < next_worker:
|
||||
next_worker -= 1
|
||||
|
||||
logging.warning(f'DGPU node offline: {uid}')
|
||||
|
||||
if len(nodes) == 0:
|
||||
logging.info('nw: None')
|
||||
logging.info('All nodes disconnected.')
|
||||
next_worker = None
|
||||
|
||||
logging.warning(f'dgpu offline: {uid}')
|
||||
|
||||
def is_worker_busy(nid: str):
|
||||
return nodes[nid]['task'] != None
|
||||
|
@ -90,8 +111,6 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
|
||||
def get_next_worker():
|
||||
nonlocal next_worker
|
||||
logging.info('get next_worker called')
|
||||
logging.info(f'pre next_worker: {next_worker}')
|
||||
|
||||
if next_worker == None:
|
||||
raise SkynetDGPUOffline('No workers connected, try again later')
|
||||
|
@ -113,392 +132,79 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
if next_worker >= len(nodes):
|
||||
next_worker = 0
|
||||
|
||||
logging.info(f'post next_worker: {next_worker}')
|
||||
|
||||
return nid
|
||||
|
||||
async def dgpu_heartbeat_service():
|
||||
nonlocal heartbeats
|
||||
while True:
|
||||
await trio.sleep(60)
|
||||
rid = uuid.uuid4().hex
|
||||
beat_msg = DGPUBusMessage(
|
||||
rid=rid,
|
||||
nid='',
|
||||
method='heartbeat'
|
||||
)
|
||||
heartbeats.clear()
|
||||
heartbeats[rid] = int(time.time() * 1000)
|
||||
await dgpu_bus.asend(beat_msg.SerializeToString())
|
||||
logging.info('sent heartbeat')
|
||||
|
||||
async def dgpu_bus_streamer():
|
||||
nonlocal wip_reqs, fin_reqs, heartbeats
|
||||
while True:
|
||||
raw_msg = await dgpu_bus.arecv()
|
||||
logging.info(f'streamer got {len(raw_msg)} bytes.')
|
||||
msg = DGPUBusMessage()
|
||||
msg.ParseFromString(raw_msg)
|
||||
|
||||
if security:
|
||||
verify_protobuf_msg(msg, tls_whitelist[msg.auth.cert])
|
||||
|
||||
rid = msg.rid
|
||||
|
||||
if msg.method == 'heartbeat':
|
||||
sent_time = heartbeats[rid]
|
||||
delta = msg.params['time'] - sent_time
|
||||
logging.info(f'got heartbeat reply from {msg.nid}, ping: {delta}')
|
||||
continue
|
||||
|
||||
if rid not in wip_reqs:
|
||||
continue
|
||||
|
||||
if msg.method == 'binary-reply':
|
||||
logging.info('bin reply, recv extra data')
|
||||
raw_img = await dgpu_bus.arecv()
|
||||
msg = (msg, raw_img)
|
||||
|
||||
fin_reqs[rid] = msg
|
||||
event = wip_reqs[rid]
|
||||
event.set()
|
||||
del wip_reqs[rid]
|
||||
|
||||
async def dgpu_stream_one_img(req: DiffusionParameters, img_buf=None):
|
||||
nonlocal wip_reqs, fin_reqs, next_worker
|
||||
nid = get_next_worker()
|
||||
idx = list(nodes.keys()).index(nid)
|
||||
logging.info(f'dgpu_stream_one_img {idx}/{len(nodes)} {nid}')
|
||||
rid = uuid.uuid4().hex
|
||||
ack_event = trio.Event()
|
||||
img_event = trio.Event()
|
||||
wip_reqs[rid] = ack_event
|
||||
|
||||
nodes[nid]['task'] = rid
|
||||
|
||||
dgpu_req = DGPUBusMessage(
|
||||
rid=rid,
|
||||
nid=nid,
|
||||
method='diffuse')
|
||||
dgpu_req.params.update(req.to_dict())
|
||||
|
||||
if security:
|
||||
dgpu_req.auth.cert = 'skynet'
|
||||
dgpu_req.auth.sig = sign_protobuf_msg(dgpu_req, tls_key)
|
||||
|
||||
msg = dgpu_req.SerializeToString()
|
||||
if img_buf:
|
||||
logging.info(f'sending img of size {len(img_buf)} as attachment')
|
||||
logging.info(img_buf[:10])
|
||||
msg = f'BINEXT%$%$'.encode() + msg + b'%$%$' + img_buf
|
||||
|
||||
await dgpu_bus.asend(msg)
|
||||
|
||||
with trio.move_on_after(4):
|
||||
await ack_event.wait()
|
||||
|
||||
logging.info(f'ack event: {ack_event.is_set()}')
|
||||
|
||||
if not ack_event.is_set():
|
||||
disconnect_node(nid)
|
||||
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
|
||||
|
||||
ack_msg = fin_reqs[rid]
|
||||
if 'ack' not in ack_msg.params:
|
||||
disconnect_node(nid)
|
||||
raise SkynetDGPUOffline('dgpu failed to acknowledge request')
|
||||
|
||||
wip_reqs[rid] = img_event
|
||||
with trio.move_on_after(30):
|
||||
await img_event.wait()
|
||||
|
||||
logging.info(f'img event: {ack_event.is_set()}')
|
||||
|
||||
if not img_event.is_set():
|
||||
disconnect_node(nid)
|
||||
raise SkynetDGPUComputeError('30 seconds timeout while processing request')
|
||||
|
||||
nodes[nid]['task'] = None
|
||||
|
||||
resp = fin_reqs[rid]
|
||||
del fin_reqs[rid]
|
||||
if isinstance(resp, tuple):
|
||||
meta, img = resp
|
||||
return rid, img, meta.params
|
||||
|
||||
raise SkynetDGPUComputeError(MessageToDict(resp.params))
|
||||
|
||||
|
||||
async def handle_user_request(rpc_ctx, req):
|
||||
try:
|
||||
async with db_pool.acquire() as conn:
|
||||
user = await get_or_create_user(conn, req.uid)
|
||||
|
||||
result = {}
|
||||
|
||||
match req.method:
|
||||
case 'txt2img':
|
||||
logging.info('txt2img')
|
||||
user_config = {**(await get_user_config(conn, user))}
|
||||
del user_config['id']
|
||||
user_config.update(MessageToDict(req.params))
|
||||
|
||||
req = DiffusionParameters(**user_config, image=False)
|
||||
rid, img, meta = await dgpu_stream_one_img(req)
|
||||
logging.info(f'done streaming {rid}')
|
||||
result = {
|
||||
'id': rid,
|
||||
'img': img.hex(),
|
||||
'meta': meta
|
||||
}
|
||||
|
||||
await update_user_stats(conn, user, last_prompt=user_config['prompt'])
|
||||
logging.info('updated user stats.')
|
||||
|
||||
case 'img2img':
|
||||
logging.info('img2img')
|
||||
user_config = {**(await get_user_config(conn, user))}
|
||||
del user_config['id']
|
||||
|
||||
params = MessageToDict(req.params)
|
||||
img_buf = bytes.fromhex(params['img'])
|
||||
del params['img']
|
||||
user_config.update(params)
|
||||
|
||||
req = DiffusionParameters(**user_config, image=True)
|
||||
|
||||
if not req.image:
|
||||
raise AssertionError('Didn\'t enable image flag for img2img?')
|
||||
|
||||
rid, img, meta = await dgpu_stream_one_img(req, img_buf=img_buf)
|
||||
logging.info(f'done streaming {rid}')
|
||||
result = {
|
||||
'id': rid,
|
||||
'img': img.hex(),
|
||||
'meta': meta
|
||||
}
|
||||
|
||||
await update_user_stats(conn, user, last_prompt=user_config['prompt'])
|
||||
logging.info('updated user stats.')
|
||||
|
||||
case 'redo':
|
||||
logging.info('redo')
|
||||
user_config = {**(await get_user_config(conn, user))}
|
||||
del user_config['id']
|
||||
prompt = await get_last_prompt_of(conn, user)
|
||||
|
||||
if prompt:
|
||||
req = DiffusionParameters(
|
||||
prompt=prompt,
|
||||
**user_config,
|
||||
image=False
|
||||
)
|
||||
rid, img, meta = await dgpu_stream_one_img(req)
|
||||
result = {
|
||||
'id': rid,
|
||||
'img': img.hex(),
|
||||
'meta': meta
|
||||
}
|
||||
await update_user_stats(conn, user)
|
||||
logging.info('updated user stats.')
|
||||
|
||||
else:
|
||||
result = {
|
||||
'error': 'skynet_no_last_prompt',
|
||||
'message': 'No prompt to redo, do txt2img first'
|
||||
}
|
||||
|
||||
case 'config':
|
||||
logging.info('config')
|
||||
if req.params['attr'] in CONFIG_ATTRS:
|
||||
logging.info(f'update: {req.params}')
|
||||
await update_user_config(
|
||||
conn, user, req.params['attr'], req.params['val'])
|
||||
logging.info('done')
|
||||
|
||||
else:
|
||||
logging.warning(f'{req.params["attr"]} not in {CONFIG_ATTRS}')
|
||||
|
||||
case 'stats':
|
||||
logging.info('stats')
|
||||
generated, joined, role = await get_user_stats(conn, user)
|
||||
|
||||
result = {
|
||||
'generated': generated,
|
||||
'joined': joined.strftime(DATE_FORMAT),
|
||||
'role': role
|
||||
}
|
||||
|
||||
case _:
|
||||
logging.warn('unknown method')
|
||||
|
||||
except SkynetDGPUOffline as e:
|
||||
result = {
|
||||
'error': 'skynet_dgpu_offline',
|
||||
'message': str(e)
|
||||
}
|
||||
|
||||
except SkynetDGPUOverloaded as e:
|
||||
result = {
|
||||
'error': 'skynet_dgpu_overloaded',
|
||||
'message': str(e),
|
||||
'nodes': len(nodes)
|
||||
}
|
||||
|
||||
except SkynetDGPUComputeError as e:
|
||||
result = {
|
||||
'error': 'skynet_dgpu_compute_error',
|
||||
'message': str(e)
|
||||
}
|
||||
except BaseException as e:
|
||||
traceback.print_exception(type(e), e, e.__traceback__)
|
||||
result = {
|
||||
'error': 'skynet_internal_error',
|
||||
'message': str(e)
|
||||
}
|
||||
|
||||
async def rpc_handler(req: SkynetRPCRequest, ctx: Context):
|
||||
result = {'ok': {}}
|
||||
resp = SkynetRPCResponse()
|
||||
resp.result.update(result)
|
||||
|
||||
if security:
|
||||
resp.auth.cert = 'skynet'
|
||||
resp.auth.sig = sign_protobuf_msg(resp, tls_key)
|
||||
|
||||
logging.info('sending response')
|
||||
await rpc_ctx.asend(resp.SerializeToString())
|
||||
rpc_ctx.close()
|
||||
logging.info('done')
|
||||
|
||||
async def request_service(n):
|
||||
nonlocal next_worker
|
||||
while True:
|
||||
ctx = sock.new_context()
|
||||
req = SkynetRPCRequest()
|
||||
req.ParseFromString(await ctx.arecv())
|
||||
|
||||
if security:
|
||||
if req.auth.cert not in tls_whitelist:
|
||||
logging.warning(
|
||||
f'{req.cert} not in tls whitelist and security=True')
|
||||
continue
|
||||
|
||||
try:
|
||||
verify_protobuf_msg(req, tls_whitelist[req.auth.cert])
|
||||
|
||||
except ValueError:
|
||||
logging.warning(
|
||||
f'{req.cert} sent an unauthenticated msg with security=True')
|
||||
continue
|
||||
|
||||
result = {}
|
||||
|
||||
try:
|
||||
match req.method:
|
||||
case 'skynet_shutdown':
|
||||
raise SkynetShutdownRequested
|
||||
|
||||
case 'dgpu_online':
|
||||
connect_node(req.uid)
|
||||
connect_node(req)
|
||||
|
||||
case 'dgpu_call':
|
||||
nid = get_next_worker()
|
||||
idx = list(nodes.keys()).index(nid)
|
||||
node = nodes[nid]
|
||||
logging.info(f'dgpu_call {idx}/{len(nodes)} {nid} @ {node["dgpu_addr"]}')
|
||||
dgpu_time = await node['session'].rpc('dgpu_time')
|
||||
if 'ok' not in dgpu_time.result:
|
||||
status = MessageToDict(dgpu_time.result)
|
||||
logging.warning(json.dumps(status, indent=4))
|
||||
disconnect_node(nid)
|
||||
raise SkynetDGPUComputeError(status['error'])
|
||||
|
||||
dgpu_time = dgpu_time.result['ok']
|
||||
logging.info(f'ping to {nid}: {time_ms() - dgpu_time} ms')
|
||||
|
||||
try:
|
||||
dgpu_result = await node['session'].rpc(
|
||||
timeout=45, # give this 45 sec to run cause its compute
|
||||
binext=req.bin,
|
||||
**req.params
|
||||
)
|
||||
result = MessageToDict(dgpu_result.result)
|
||||
|
||||
if dgpu_result.bin:
|
||||
resp.bin = dgpu_result.bin
|
||||
|
||||
except trio.TooSlowError:
|
||||
result = {'error': 'timeout while processing request'}
|
||||
|
||||
case 'dgpu_offline':
|
||||
disconnect_node(req.uid)
|
||||
|
||||
case 'dgpu_workers':
|
||||
result = len(nodes)
|
||||
result = {'ok': len(nodes)}
|
||||
|
||||
case 'dgpu_next':
|
||||
result = next_worker
|
||||
result = {'ok': next_worker}
|
||||
|
||||
case 'heartbeat':
|
||||
logging.info('beat')
|
||||
result = {'time': time.time()}
|
||||
case 'skynet_shutdown':
|
||||
raise SkynetShutdownRequested
|
||||
|
||||
case _:
|
||||
n.start_soon(
|
||||
handle_user_request, ctx, req)
|
||||
continue
|
||||
logging.warning(f'Unknown method {req.method}')
|
||||
result = {'error': 'unknown method'}
|
||||
|
||||
resp = SkynetRPCResponse()
|
||||
resp.result.update({'ok': result})
|
||||
except BaseException as e:
|
||||
result = {'error': str(e)}
|
||||
|
||||
if security:
|
||||
resp.auth.cert = 'skynet'
|
||||
resp.auth.sig = sign_protobuf_msg(resp, tls_key)
|
||||
resp.result.update(result)
|
||||
|
||||
await ctx.asend(resp.SerializeToString())
|
||||
return resp
|
||||
|
||||
ctx.close()
|
||||
rpc_server = SessionServer(
|
||||
rpc_address,
|
||||
rpc_handler,
|
||||
cert_name='brain.cert',
|
||||
key_name='brain.key'
|
||||
)
|
||||
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
n.start_soon(dgpu_bus_streamer)
|
||||
n.start_soon(dgpu_heartbeat_service)
|
||||
n.start_soon(request_service, n)
|
||||
logging.info('starting rpc service')
|
||||
async with rpc_server.open():
|
||||
logging.info('rpc server is up')
|
||||
yield
|
||||
logging.info('stopping rpc service')
|
||||
n.cancel_scope.cancel()
|
||||
logging.info('skynet is shuting down...')
|
||||
|
||||
|
||||
@acm
|
||||
async def run_skynet(
|
||||
db_user: str = DB_USER,
|
||||
db_pass: str = DB_PASS,
|
||||
db_host: str = DB_HOST,
|
||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
||||
security: bool = True
|
||||
):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.info('skynet is starting')
|
||||
|
||||
tls_config = None
|
||||
if security:
|
||||
# load tls certs
|
||||
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
|
||||
|
||||
tls_key_data = (certs_dir / DEFAULT_CERT_SKYNET_PRIV).read_text()
|
||||
tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
|
||||
|
||||
tls_cert_data = (certs_dir / DEFAULT_CERT_SKYNET_PUB).read_text()
|
||||
tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data)
|
||||
|
||||
tls_whitelist = {}
|
||||
for cert_path in (certs_dir / 'whitelist').glob('*.cert'):
|
||||
tls_whitelist[cert_path.stem] = load_certificate(
|
||||
FILETYPE_PEM, cert_path.read_text())
|
||||
|
||||
cert_start = tls_cert_data.index('\n') + 1
|
||||
logging.info(f'tls_cert: {tls_cert_data[cert_start:cert_start+64]}...')
|
||||
logging.info(f'tls_whitelist len: {len(tls_whitelist)}')
|
||||
|
||||
rpc_address = 'tls+' + rpc_address
|
||||
dgpu_address = 'tls+' + dgpu_address
|
||||
tls_config = TLSConfig(
|
||||
TLSConfig.MODE_SERVER,
|
||||
own_key_string=tls_key_data,
|
||||
own_cert_string=tls_cert_data)
|
||||
|
||||
with (
|
||||
pynng.Rep0(recv_max_size=0) as rpc_sock,
|
||||
pynng.Bus0(recv_max_size=0) as dgpu_bus
|
||||
):
|
||||
async with open_database_connection(
|
||||
db_user, db_pass, db_host) as db_pool:
|
||||
|
||||
logging.info('connected to db.')
|
||||
if security:
|
||||
rpc_sock.tls_config = tls_config
|
||||
dgpu_bus.tls_config = tls_config
|
||||
|
||||
rpc_sock.listen(rpc_address)
|
||||
dgpu_bus.listen(dgpu_address)
|
||||
|
||||
try:
|
||||
async with open_rpc_service(
|
||||
rpc_sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||
yield
|
||||
|
||||
except SkynetShutdownRequested:
|
||||
...
|
||||
|
||||
logging.info('disconnected from db.')
|
||||
logging.info('skynet down.')
|
||||
|
|
|
@ -17,8 +17,8 @@ if torch_enabled:
|
|||
from .dgpu import open_dgpu_node
|
||||
|
||||
from .brain import run_skynet
|
||||
from .config import *
|
||||
from .constants import ALGOS, DEFAULT_RPC_ADDR, DEFAULT_DGPU_ADDR
|
||||
|
||||
from .frontend.telegram import run_skynet_telegram
|
||||
|
||||
|
||||
|
@ -38,8 +38,8 @@ def skynet(*args, **kwargs):
|
|||
@click.option('--steps', '-s', default=26)
|
||||
@click.option('--seed', '-S', default=None)
|
||||
def txt2img(*args, **kwargs):
|
||||
assert 'HF_TOKEN' in os.environ
|
||||
utils.txt2img(os.environ['HF_TOKEN'], **kwargs)
|
||||
_, hf_token, _, cfg = init_env_from_config()
|
||||
utils.txt2img(hf_token, **kwargs)
|
||||
|
||||
@click.command()
|
||||
@click.option('--model', '-m', default='midj')
|
||||
|
@ -52,9 +52,9 @@ def txt2img(*args, **kwargs):
|
|||
@click.option('--steps', '-s', default=26)
|
||||
@click.option('--seed', '-S', default=None)
|
||||
def img2img(model, prompt, input, output, strength, guidance, steps, seed):
|
||||
assert 'HF_TOKEN' in os.environ
|
||||
_, hf_token, _, cfg = init_env_from_config()
|
||||
utils.img2img(
|
||||
os.environ['HF_TOKEN'],
|
||||
hf_token,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
img_path=input,
|
||||
|
@ -76,6 +76,12 @@ def upscale(input, output, model):
|
|||
model_path=model)
|
||||
|
||||
|
||||
@skynet.command()
|
||||
def download():
|
||||
_, hf_token, _, cfg = init_env_from_config()
|
||||
utils.download_all_models(hf_token)
|
||||
|
||||
|
||||
@skynet.group()
|
||||
def run(*args, **kwargs):
|
||||
pass
|
||||
|
@ -85,29 +91,17 @@ def run(*args, **kwargs):
|
|||
@click.option('--loglevel', '-l', default='warning', help='Logging level')
|
||||
@click.option(
|
||||
'--host', '-H', default=DEFAULT_RPC_ADDR)
|
||||
@click.option(
|
||||
'--host-dgpu', '-D', default=DEFAULT_DGPU_ADDR)
|
||||
@click.option(
|
||||
'--db-host', '-h', default='localhost:5432')
|
||||
@click.option(
|
||||
'--db-pass', '-p', default='password')
|
||||
def brain(
|
||||
loglevel: str,
|
||||
host: str,
|
||||
host_dgpu: str,
|
||||
db_host: str,
|
||||
db_pass: str
|
||||
host: str
|
||||
):
|
||||
async def _run_skynet():
|
||||
async with run_skynet(
|
||||
db_host=db_host,
|
||||
db_pass=db_pass,
|
||||
rpc_address=host,
|
||||
dgpu_address=host_dgpu
|
||||
rpc_address=host
|
||||
):
|
||||
await trio.sleep_forever()
|
||||
|
||||
trio_asyncio.run(_run_skynet)
|
||||
trio.run(_run_skynet)
|
||||
|
||||
|
||||
@run.command()
|
||||
|
@ -115,9 +109,9 @@ def brain(
|
|||
@click.option(
|
||||
'--uid', '-u', required=True)
|
||||
@click.option(
|
||||
'--key', '-k', default='dgpu')
|
||||
'--key', '-k', default='dgpu.key')
|
||||
@click.option(
|
||||
'--cert', '-c', default='whitelist/dgpu')
|
||||
'--cert', '-c', default='whitelist/dgpu.cert')
|
||||
@click.option(
|
||||
'--algos', '-a', default=json.dumps(['midj']))
|
||||
@click.option(
|
||||
|
@ -159,11 +153,11 @@ def telegram(
|
|||
cert: str,
|
||||
rpc: str
|
||||
):
|
||||
assert 'TG_TOKEN' in os.environ
|
||||
_, _, tg_token, cfg = init_env_from_config()
|
||||
trio_asyncio.run(
|
||||
partial(
|
||||
run_skynet_telegram,
|
||||
os.environ['TG_TOKEN'],
|
||||
tg_token,
|
||||
key_name=key,
|
||||
cert_name=cert,
|
||||
rpc_address=rpc
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import os
|
||||
|
||||
from pathlib import Path
|
||||
from configparser import ConfigParser
|
||||
|
||||
from .constants import DEFAULT_CONFIG_PATH
|
||||
|
||||
|
||||
def load_skynet_ini(
|
||||
file_path=DEFAULT_CONFIG_PATH
|
||||
):
|
||||
config = ConfigParser()
|
||||
config.read(file_path)
|
||||
return config
|
||||
|
||||
|
||||
def init_env_from_config(
|
||||
file_path=DEFAULT_CONFIG_PATH
|
||||
):
|
||||
config = load_skynet_ini()
|
||||
|
||||
if 'HF_TOKEN' in os.environ:
|
||||
hf_token = os.environ['HF_TOKEN']
|
||||
else:
|
||||
hf_token = config['skynet.dgpu']['hf_token']
|
||||
|
||||
if 'HF_HOME' in os.environ:
|
||||
hf_home = os.environ['HF_HOME']
|
||||
else:
|
||||
hf_home = config['skynet.dgpu']['hf_home']
|
||||
|
||||
if 'TG_TOKEN' in os.environ:
|
||||
tg_token = os.environ['TG_TOKEN']
|
||||
else:
|
||||
tg_token = config['skynet.telegram']['token']
|
||||
|
||||
return hf_home, hf_token, tg_token, config
|
|
@ -1,14 +1,9 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
VERSION = '0.1a8'
|
||||
VERSION = '0.1a9'
|
||||
|
||||
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
|
||||
|
||||
DB_HOST = 'localhost:5432'
|
||||
DB_USER = 'skynet'
|
||||
DB_PASS = 'password'
|
||||
DB_NAME = 'skynet'
|
||||
|
||||
ALGOS = {
|
||||
'midj': 'prompthero/openjourney',
|
||||
'stable': 'runwayml/stable-diffusion-v1-5',
|
||||
|
@ -118,6 +113,7 @@ DEFAULT_ALGO = 'midj'
|
|||
DEFAULT_ROLE = 'pleb'
|
||||
DEFAULT_UPSCALER = None
|
||||
|
||||
DEFAULT_CONFIG_PATH = 'skynet.ini'
|
||||
DEFAULT_CERTS_DIR = 'certs'
|
||||
DEFAULT_CERT_WHITELIST_DIR = 'whitelist'
|
||||
DEFAULT_CERT_SKYNET_PUB = 'brain.cert'
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
from .proxy import open_database_connection
|
||||
|
||||
from .functions import open_new_database
|
|
@ -1,18 +1,21 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import time
|
||||
import random
|
||||
import string
|
||||
import logging
|
||||
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from contextlib import asynccontextmanager as acm
|
||||
from contextlib import contextmanager as cm
|
||||
|
||||
import trio
|
||||
import triopg
|
||||
import trio_asyncio
|
||||
import docker
|
||||
import psycopg2
|
||||
|
||||
from asyncpg.exceptions import UndefinedColumnError
|
||||
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
|
||||
|
||||
from .constants import *
|
||||
from ..constants import *
|
||||
|
||||
|
||||
DB_INIT_SQL = '''
|
||||
|
@ -75,29 +78,67 @@ def try_decode_uid(uid: str):
|
|||
return None, None
|
||||
|
||||
|
||||
@acm
|
||||
async def open_database_connection(
|
||||
db_user: str = DB_USER,
|
||||
db_pass: str = DB_PASS,
|
||||
db_host: str = DB_HOST,
|
||||
db_name: str = DB_NAME
|
||||
):
|
||||
async with trio_asyncio.open_loop() as loop:
|
||||
async with triopg.create_pool(
|
||||
dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}'
|
||||
) as pool_conn:
|
||||
async with pool_conn.acquire() as conn:
|
||||
res = await conn.execute(f'''
|
||||
select distinct table_schema
|
||||
from information_schema.tables
|
||||
where table_schema = \'{db_name}\'
|
||||
''')
|
||||
if '1' in res:
|
||||
logging.info('schema already in db, skipping init')
|
||||
else:
|
||||
await conn.execute(DB_INIT_SQL)
|
||||
@cm
|
||||
def open_new_database():
|
||||
rpassword = ''.join(
|
||||
random.choice(string.ascii_lowercase)
|
||||
for i in range(12))
|
||||
password = ''.join(
|
||||
random.choice(string.ascii_lowercase)
|
||||
for i in range(12))
|
||||
|
||||
yield pool_conn
|
||||
dclient = docker.from_env()
|
||||
|
||||
container = dclient.containers.run(
|
||||
'postgres',
|
||||
name='skynet-test-postgres',
|
||||
ports={'5432/tcp': None},
|
||||
environment={
|
||||
'POSTGRES_PASSWORD': rpassword
|
||||
},
|
||||
detach=True,
|
||||
remove=True
|
||||
)
|
||||
|
||||
for log in container.logs(stream=True):
|
||||
log = log.decode().rstrip()
|
||||
logging.info(log)
|
||||
if ('database system is ready to accept connections' in log or
|
||||
'database system is shut down' in log):
|
||||
break
|
||||
|
||||
# ip = container.attrs['NetworkSettings']['IPAddress']
|
||||
container.reload()
|
||||
port = container.ports['5432/tcp'][0]['HostPort']
|
||||
host = f'localhost:{port}'
|
||||
|
||||
# why print the system is ready to accept connections when its not
|
||||
# postgres? wtf
|
||||
time.sleep(1)
|
||||
logging.info('creating skynet db...')
|
||||
|
||||
conn = psycopg2.connect(
|
||||
user='postgres',
|
||||
password=rpassword,
|
||||
host='localhost',
|
||||
port=port
|
||||
)
|
||||
logging.info('connected...')
|
||||
conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
f'CREATE USER skynet WITH PASSWORD \'{password}\'')
|
||||
cursor.execute(
|
||||
f'CREATE DATABASE skynet')
|
||||
cursor.execute(
|
||||
f'GRANT ALL PRIVILEGES ON DATABASE skynet TO skynet')
|
||||
|
||||
conn.close()
|
||||
|
||||
logging.info('done.')
|
||||
yield container, password, host
|
||||
|
||||
container.stop()
|
||||
|
||||
|
||||
async def get_user(conn, uid: str):
|
|
@ -0,0 +1,123 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import importlib
|
||||
|
||||
from contextlib import asynccontextmanager as acm
|
||||
|
||||
import trio
|
||||
import tractor
|
||||
import asyncpg
|
||||
import asyncio
|
||||
import trio_asyncio
|
||||
|
||||
|
||||
_spawn_kwargs = {
|
||||
'infect_asyncio': True,
|
||||
}
|
||||
|
||||
|
||||
async def aio_db_proxy(
|
||||
to_trio: trio.MemorySendChannel,
|
||||
from_trio: asyncio.Queue,
|
||||
db_user: str = 'skynet',
|
||||
db_pass: str = 'password',
|
||||
db_host: str = 'localhost:5432',
|
||||
db_name: str = 'skynet'
|
||||
) -> None:
|
||||
db = importlib.import_module('skynet.db.functions')
|
||||
|
||||
pool = await asyncpg.create_pool(
|
||||
dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}')
|
||||
|
||||
async with pool_conn.acquire() as conn:
|
||||
res = await conn.execute(f'''
|
||||
select distinct table_schema
|
||||
from information_schema.tables
|
||||
where table_schema = \'{db_name}\'
|
||||
''')
|
||||
if '1' in res:
|
||||
logging.info('schema already in db, skipping init')
|
||||
else:
|
||||
await conn.execute(DB_INIT_SQL)
|
||||
|
||||
# a first message must be sent **from** this ``asyncio``
|
||||
# task or the ``trio`` side will never unblock from
|
||||
# ``tractor.to_asyncio.open_channel_from():``
|
||||
to_trio.send_nowait('start')
|
||||
|
||||
# XXX: this uses an ``from_trio: asyncio.Queue`` currently but we
|
||||
# should probably offer something better.
|
||||
while True:
|
||||
msg = await from_trio.get()
|
||||
|
||||
method = getattr(db, msg.get('method'))
|
||||
args = getattr(db, msg.get('args', []))
|
||||
kwargs = getattr(db, msg.get('kwargs', {}))
|
||||
|
||||
async with pool_conn.acquire() as conn:
|
||||
result = await method(conn, *args, **kwargs)
|
||||
to_trio.send_nowait(result)
|
||||
|
||||
|
||||
@tractor.context
|
||||
async def trio_to_aio_db_proxy(
|
||||
ctx: tractor.Context,
|
||||
db_user: str = 'skynet',
|
||||
db_pass: str = 'password',
|
||||
db_host: str = 'localhost:5432',
|
||||
db_name: str = 'skynet'
|
||||
):
|
||||
# this will block until the ``asyncio`` task sends a "first"
|
||||
# message.
|
||||
async with tractor.to_asyncio.open_channel_from(
|
||||
aio_db_proxy,
|
||||
db_user=db_user,
|
||||
db_pass=db_pass,
|
||||
db_host=db_host,
|
||||
db_name=db_name
|
||||
) as (first, chan):
|
||||
|
||||
assert first == 'start'
|
||||
await ctx.started(first)
|
||||
|
||||
async with ctx.open_stream() as stream:
|
||||
|
||||
async for msg in stream:
|
||||
await chan.send(msg)
|
||||
|
||||
out = await chan.receive()
|
||||
# echo back to parent actor-task
|
||||
await stream.send(out)
|
||||
|
||||
|
||||
@acm
|
||||
async def open_database_connection(
|
||||
db_user: str = 'skynet',
|
||||
db_pass: str = 'password',
|
||||
db_host: str = 'localhost:5432',
|
||||
db_name: str = 'skynet'
|
||||
):
|
||||
async with tractor.open_nursery() as n:
|
||||
p = await n.start_actor(
|
||||
'aio_db_proxy',
|
||||
enable_modules=[__name__],
|
||||
infect_asyncio=True,
|
||||
)
|
||||
async with p.open_context(
|
||||
trio_to_aio_db_proxy,
|
||||
db_user=db_user,
|
||||
db_pass=db_pass,
|
||||
db_host=db_host,
|
||||
db_name=db_name
|
||||
) as (ctx, first):
|
||||
async with ctx.open_stream() as stream:
|
||||
|
||||
async def _db_pc(method: str, *args, **kwargs):
|
||||
await stream.send({
|
||||
'method': method,
|
||||
'args': args,
|
||||
'kwargs': kwargs
|
||||
})
|
||||
return await stream.receive()
|
||||
|
||||
yield _db_pc
|
405
skynet/dgpu.py
405
skynet/dgpu.py
|
@ -2,29 +2,17 @@
|
|||
|
||||
import gc
|
||||
import io
|
||||
import trio
|
||||
import json
|
||||
import uuid
|
||||
import time
|
||||
import zlib
|
||||
import random
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from PIL import Image
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
from contextlib import ExitStack
|
||||
|
||||
import pynng
|
||||
import trio
|
||||
import torch
|
||||
|
||||
from pynng import TLSConfig
|
||||
from OpenSSL.crypto import (
|
||||
load_privatekey,
|
||||
load_certificate,
|
||||
FILETYPE_PEM
|
||||
)
|
||||
from pynng import Context
|
||||
from diffusers import (
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
|
@ -34,12 +22,9 @@ from realesrgan import RealESRGANer
|
|||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
|
||||
from .utils import (
|
||||
pipeline_for,
|
||||
convert_from_cv2_to_image, convert_from_image_to_cv2
|
||||
)
|
||||
from .utils import *
|
||||
from .network import *
|
||||
from .protobuf import *
|
||||
from .frontend import open_skynet_rpc
|
||||
from .constants import *
|
||||
|
||||
|
||||
|
@ -64,65 +49,16 @@ class DGPUComputeError(BaseException):
|
|||
...
|
||||
|
||||
|
||||
class ReconnectingBus:
|
||||
|
||||
def __init__(self, address: str, tls_config: Optional[TLSConfig]):
|
||||
self.address = address
|
||||
self.tls_config = tls_config
|
||||
|
||||
self._stack = ExitStack()
|
||||
self._sock = None
|
||||
self._closed = True
|
||||
|
||||
def connect(self):
|
||||
self._sock = self._stack.enter_context(
|
||||
pynng.Bus0(recv_max_size=0))
|
||||
self._sock.tls_config = self.tls_config
|
||||
self._sock.dial(self.address)
|
||||
self._closed = False
|
||||
|
||||
async def arecv(self):
|
||||
while True:
|
||||
try:
|
||||
return await self._sock.arecv()
|
||||
|
||||
except pynng.exceptions.Closed:
|
||||
if self._closed:
|
||||
raise
|
||||
|
||||
async def asend(self, msg):
|
||||
while True:
|
||||
try:
|
||||
return await self._sock.asend(msg)
|
||||
|
||||
except pynng.exceptions.Closed:
|
||||
if self._closed:
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
self._stack.close()
|
||||
self._stack = ExitStack()
|
||||
self._closed = True
|
||||
|
||||
def reconnect(self):
|
||||
self.close()
|
||||
self.connect()
|
||||
|
||||
|
||||
async def open_dgpu_node(
|
||||
cert_name: str,
|
||||
unique_id: str,
|
||||
key_name: Optional[str],
|
||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||
dgpu_address: str = DEFAULT_DGPU_ADDR,
|
||||
initial_algos: Optional[List[str]] = None,
|
||||
security: bool = True
|
||||
initial_algos: Optional[List[str]] = None
|
||||
):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logging.info(f'starting dgpu node!')
|
||||
|
||||
name = uuid.uuid4()
|
||||
|
||||
logging.info(f'loading models...')
|
||||
|
||||
upscaler = init_upscaler()
|
||||
|
@ -141,241 +77,140 @@ async def open_dgpu_node(
|
|||
logging.info('memory summary:')
|
||||
logging.info('\n' + torch.cuda.memory_summary())
|
||||
|
||||
async def gpu_compute_one(ireq: DiffusionParameters, image=None):
|
||||
algo = ireq.algo + 'img' if image else ireq.algo
|
||||
if algo not in models:
|
||||
least_used = list(models.keys())[0]
|
||||
for model in models:
|
||||
if models[least_used]['generated'] > models[model]['generated']:
|
||||
least_used = model
|
||||
async def gpu_compute_one(method: str, params: dict, binext: Optional[bytes] = None):
|
||||
match method:
|
||||
case 'diffuse':
|
||||
image = None
|
||||
algo = params['algo']
|
||||
if binext:
|
||||
algo += 'img'
|
||||
image = Image.open(io.BytesIO(binext))
|
||||
w, h = image.size
|
||||
logging.info(f'user sent img of size {image.size}')
|
||||
|
||||
del models[least_used]
|
||||
gc.collect()
|
||||
if w > 512 or h > 512:
|
||||
image.thumbnail((512, 512))
|
||||
logging.info(f'resized it to {image.size}')
|
||||
|
||||
models[algo] = {
|
||||
'pipe': pipeline_for(ireq.algo, image=True if image else False),
|
||||
'generated': 0
|
||||
}
|
||||
if algo not in models:
|
||||
logging.info(f'{algo} not in loaded models, swapping...')
|
||||
least_used = list(models.keys())[0]
|
||||
for model in models:
|
||||
if models[least_used]['generated'] > models[model]['generated']:
|
||||
least_used = model
|
||||
|
||||
_params = {}
|
||||
if ireq.image:
|
||||
_params['image'] = image
|
||||
_params['strength'] = ireq.strength
|
||||
del models[least_used]
|
||||
gc.collect()
|
||||
|
||||
else:
|
||||
_params['width'] = int(ireq.width)
|
||||
_params['height'] = int(ireq.height)
|
||||
models[algo] = {
|
||||
'pipe': pipeline_for(params['algo'], image=True if binext else False),
|
||||
'generated': 0
|
||||
}
|
||||
logging.info(f'swapping done.')
|
||||
|
||||
try:
|
||||
image = models[algo]['pipe'](
|
||||
ireq.prompt,
|
||||
**_params,
|
||||
guidance_scale=ireq.guidance,
|
||||
num_inference_steps=int(ireq.step),
|
||||
generator=torch.Generator("cuda").manual_seed(ireq.seed)
|
||||
).images[0]
|
||||
_params = {}
|
||||
logging.info(method)
|
||||
logging.info(json.dumps(params, indent=4))
|
||||
logging.info(f'binext: {len(binext) if binext else 0} bytes')
|
||||
if binext:
|
||||
_params['image'] = image
|
||||
_params['strength'] = params['strength']
|
||||
|
||||
if ireq.upscaler == 'x4':
|
||||
logging.info(f'size: {len(image.tobytes())}')
|
||||
logging.info('performing upscale...')
|
||||
input_img = image.convert('RGB')
|
||||
up_img, _ = upscaler.enhance(
|
||||
convert_from_image_to_cv2(input_img), outscale=4)
|
||||
else:
|
||||
_params['width'] = int(params['width'])
|
||||
_params['height'] = int(params['height'])
|
||||
|
||||
image = convert_from_cv2_to_image(up_img)
|
||||
logging.info('done')
|
||||
try:
|
||||
image = models[algo]['pipe'](
|
||||
params['prompt'],
|
||||
**_params,
|
||||
guidance_scale=params['guidance'],
|
||||
num_inference_steps=int(params['step']),
|
||||
generator=torch.Generator("cuda").manual_seed(
|
||||
int(params['seed']) if params['seed'] else random.randint(0, 2 ** 64)
|
||||
)
|
||||
).images[0]
|
||||
|
||||
img_byte_arr = io.BytesIO()
|
||||
image.save(img_byte_arr, format='PNG')
|
||||
raw_img = img_byte_arr.getvalue()
|
||||
logging.info(f'final img size {len(raw_img)} bytes.')
|
||||
if params['upscaler'] == 'x4':
|
||||
logging.info(f'size: {len(image.tobytes())}')
|
||||
logging.info('performing upscale...')
|
||||
input_img = image.convert('RGB')
|
||||
up_img, _ = upscaler.enhance(
|
||||
convert_from_image_to_cv2(input_img), outscale=4)
|
||||
|
||||
return raw_img
|
||||
image = convert_from_cv2_to_image(up_img)
|
||||
logging.info('done')
|
||||
|
||||
except BaseException as e:
|
||||
logging.error(e)
|
||||
raise DGPUComputeError(str(e))
|
||||
img_byte_arr = io.BytesIO()
|
||||
image.save(img_byte_arr, format='PNG')
|
||||
raw_img = img_byte_arr.getvalue()
|
||||
logging.info(f'final img size {len(raw_img)} bytes.')
|
||||
|
||||
finally:
|
||||
torch.cuda.empty_cache()
|
||||
return raw_img
|
||||
|
||||
except BaseException as e:
|
||||
logging.error(e)
|
||||
raise DGPUComputeError(str(e))
|
||||
|
||||
finally:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
case _:
|
||||
raise DGPUComputeError('Unsupported compute method')
|
||||
|
||||
async def rpc_handler(req: SkynetRPCRequest, ctx: Context):
|
||||
result = {}
|
||||
resp = SkynetRPCResponse()
|
||||
|
||||
match req.method:
|
||||
case 'dgpu_time':
|
||||
result = {'ok': time_ms()}
|
||||
|
||||
case _:
|
||||
logging.debug(f'dgpu got one request: {req.method}')
|
||||
try:
|
||||
resp.bin = await gpu_compute_one(
|
||||
req.method, MessageToDict(req.params),
|
||||
binext=req.bin if req.bin else None
|
||||
)
|
||||
logging.debug(f'dgpu processed one request')
|
||||
|
||||
except DGPUComputeError as e:
|
||||
result = {'error': str(e)}
|
||||
|
||||
resp.result.update(result)
|
||||
return resp
|
||||
|
||||
rpc_server = SessionServer(
|
||||
dgpu_address,
|
||||
rpc_handler,
|
||||
cert_name=cert_name,
|
||||
key_name=key_name
|
||||
)
|
||||
skynet_rpc = SessionClient(
|
||||
rpc_address,
|
||||
unique_id,
|
||||
cert_name=cert_name,
|
||||
key_name=key_name
|
||||
)
|
||||
skynet_rpc.connect()
|
||||
|
||||
|
||||
async with (
|
||||
open_skynet_rpc(
|
||||
unique_id,
|
||||
rpc_address=rpc_address,
|
||||
security=security,
|
||||
cert_name=cert_name,
|
||||
key_name=key_name
|
||||
) as rpc_call,
|
||||
trio.open_nursery() as n
|
||||
):
|
||||
async with rpc_server.open() as rpc_server:
|
||||
res = await skynet_rpc.rpc(
|
||||
'dgpu_online', {
|
||||
'dgpu_addr': rpc_server.addr,
|
||||
'cert': cert_name
|
||||
})
|
||||
|
||||
tls_config = None
|
||||
if security:
|
||||
# load tls certs
|
||||
if not key_name:
|
||||
key_name = cert_name
|
||||
|
||||
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
|
||||
|
||||
skynet_cert_path = certs_dir / 'brain.cert'
|
||||
tls_cert_path = certs_dir / f'{cert_name}.cert'
|
||||
tls_key_path = certs_dir / f'{key_name}.key'
|
||||
|
||||
cert_name = tls_cert_path.stem
|
||||
|
||||
skynet_cert_data = skynet_cert_path.read_text()
|
||||
skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data)
|
||||
|
||||
tls_cert_data = tls_cert_path.read_text()
|
||||
|
||||
tls_key_data = tls_key_path.read_text()
|
||||
tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
|
||||
|
||||
logging.info(f'skynet cert: {skynet_cert_path}')
|
||||
logging.info(f'dgpu cert: {tls_cert_path}')
|
||||
logging.info(f'dgpu key: {tls_key_path}')
|
||||
|
||||
dgpu_address = 'tls+' + dgpu_address
|
||||
tls_config = TLSConfig(
|
||||
TLSConfig.MODE_CLIENT,
|
||||
own_key_string=tls_key_data,
|
||||
own_cert_string=tls_cert_data,
|
||||
ca_string=skynet_cert_data)
|
||||
|
||||
logging.info(f'connecting to {dgpu_address}')
|
||||
|
||||
dgpu_bus = ReconnectingBus(dgpu_address, tls_config)
|
||||
dgpu_bus.connect()
|
||||
|
||||
last_msg = time.time()
|
||||
async def connection_refresher(refresh_time: int = 120):
|
||||
nonlocal last_msg
|
||||
while True:
|
||||
now = time.time()
|
||||
last_msg_time_delta = now - last_msg
|
||||
logging.info(f'time since last msg: {last_msg_time_delta}')
|
||||
if last_msg_time_delta > refresh_time:
|
||||
dgpu_bus.reconnect()
|
||||
logging.info('reconnected!')
|
||||
last_msg = now
|
||||
|
||||
await trio.sleep(refresh_time)
|
||||
|
||||
n.start_soon(connection_refresher)
|
||||
|
||||
res = await rpc_call('dgpu_online')
|
||||
assert 'ok' in res.result
|
||||
|
||||
try:
|
||||
while True:
|
||||
msg = await dgpu_bus.arecv()
|
||||
|
||||
img = None
|
||||
if b'BINEXT' in msg:
|
||||
header, msg, img_raw = msg.split(b'%$%$')
|
||||
logging.info(f'got img attachment of size {len(img_raw)}')
|
||||
logging.info(img_raw[:10])
|
||||
raw_img = zlib.decompress(img_raw)
|
||||
logging.info(raw_img[:10])
|
||||
img = Image.open(io.BytesIO(raw_img))
|
||||
w, h = img.size
|
||||
logging.info(f'user sent img of size {img.size}')
|
||||
|
||||
if w > 512 or h > 512:
|
||||
img.thumbnail((512, 512))
|
||||
logging.info(f'resized it to {img.size}')
|
||||
|
||||
|
||||
req = DGPUBusMessage()
|
||||
req.ParseFromString(msg)
|
||||
last_msg = time.time()
|
||||
|
||||
if req.method == 'heartbeat':
|
||||
rep = DGPUBusMessage(
|
||||
rid=req.rid,
|
||||
nid=unique_id,
|
||||
method=req.method
|
||||
)
|
||||
rep.params.update({'time': int(time.time() * 1000)})
|
||||
|
||||
if security:
|
||||
rep.auth.cert = cert_name
|
||||
rep.auth.sig = sign_protobuf_msg(rep, tls_key)
|
||||
|
||||
await dgpu_bus.asend(rep.SerializeToString())
|
||||
logging.info('heartbeat reply')
|
||||
continue
|
||||
|
||||
if req.nid != unique_id:
|
||||
logging.info(
|
||||
f'witnessed msg {req.rid}, node involved: {req.nid}')
|
||||
continue
|
||||
|
||||
if security:
|
||||
verify_protobuf_msg(req, skynet_cert)
|
||||
|
||||
|
||||
ack_resp = DGPUBusMessage(
|
||||
rid=req.rid,
|
||||
nid=req.nid
|
||||
)
|
||||
ack_resp.params.update({'ack': {}})
|
||||
|
||||
if security:
|
||||
ack_resp.auth.cert = cert_name
|
||||
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
|
||||
|
||||
# send ack
|
||||
await dgpu_bus.asend(ack_resp.SerializeToString())
|
||||
|
||||
logging.info(f'sent ack, processing {req.rid}...')
|
||||
|
||||
try:
|
||||
img_req = DiffusionParameters(**req.params)
|
||||
|
||||
if not img_req.seed:
|
||||
img_req.seed = random.randint(0, 2 ** 64)
|
||||
|
||||
img = await gpu_compute_one(img_req, image=img)
|
||||
img_resp = DGPUBusMessage(
|
||||
rid=req.rid,
|
||||
nid=req.nid,
|
||||
method='binary-reply'
|
||||
)
|
||||
img_resp.params.update({
|
||||
'len': len(img),
|
||||
'meta': img_req.to_dict()
|
||||
})
|
||||
|
||||
except DGPUComputeError as e:
|
||||
traceback.print_exception(type(e), e, e.__traceback__)
|
||||
img_resp = DGPUBusMessage(
|
||||
rid=req.rid,
|
||||
nid=req.nid
|
||||
)
|
||||
img_resp.params.update({'error': str(e)})
|
||||
|
||||
|
||||
if security:
|
||||
img_resp.auth.cert = cert_name
|
||||
img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key)
|
||||
|
||||
# send final image
|
||||
logging.info('sending img back...')
|
||||
raw_msg = img_resp.SerializeToString()
|
||||
await dgpu_bus.asend(raw_msg)
|
||||
logging.info(f'sent {len(raw_msg)} bytes.')
|
||||
if img_resp.method == 'binary-reply':
|
||||
await dgpu_bus.asend(zlib.compress(img))
|
||||
logging.info(f'sent {len(img)} bytes.')
|
||||
await trio.sleep_forever()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logging.info('interrupt caught, stopping...')
|
||||
n.cancel_scope.cancel()
|
||||
dgpu_bus.close()
|
||||
|
||||
finally:
|
||||
res = await rpc_call('dgpu_offline')
|
||||
res = await skynet_rpc.rpc('dgpu_offline')
|
||||
assert 'ok' in res.result
|
||||
|
|
|
@ -4,7 +4,7 @@ import json
|
|||
|
||||
from typing import Union, Optional
|
||||
from pathlib import Path
|
||||
from contextlib import asynccontextmanager as acm
|
||||
from contextlib import contextmanager as cm
|
||||
|
||||
import pynng
|
||||
|
||||
|
@ -17,6 +17,7 @@ from OpenSSL.crypto import (
|
|||
|
||||
from google.protobuf.struct_pb2 import Struct
|
||||
|
||||
from ..network import SessionClient
|
||||
from ..constants import *
|
||||
|
||||
from ..protobuf.auth import *
|
||||
|
@ -39,75 +40,23 @@ class ConfigSizeDivisionByEight(BaseException):
|
|||
...
|
||||
|
||||
|
||||
@acm
|
||||
async def open_skynet_rpc(
|
||||
@cm
|
||||
def open_skynet_rpc(
|
||||
unique_id: str,
|
||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||
security: bool = False,
|
||||
cert_name: Optional[str] = None,
|
||||
key_name: Optional[str] = None
|
||||
):
|
||||
tls_config = None
|
||||
|
||||
if security:
|
||||
# load tls certs
|
||||
if not key_name:
|
||||
key_name = cert_name
|
||||
|
||||
certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
|
||||
|
||||
skynet_cert_data = (certs_dir / 'brain.cert').read_text()
|
||||
skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data)
|
||||
|
||||
tls_cert_path = certs_dir / f'{cert_name}.cert'
|
||||
tls_cert_data = tls_cert_path.read_text()
|
||||
tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data)
|
||||
cert_name = tls_cert_path.stem
|
||||
|
||||
tls_key_data = (certs_dir / f'{key_name}.key').read_text()
|
||||
tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
|
||||
|
||||
rpc_address = 'tls+' + rpc_address
|
||||
tls_config = TLSConfig(
|
||||
TLSConfig.MODE_CLIENT,
|
||||
own_key_string=tls_key_data,
|
||||
own_cert_string=tls_cert_data,
|
||||
ca_string=skynet_cert_data)
|
||||
|
||||
with pynng.Req0(recv_max_size=0) as sock:
|
||||
if security:
|
||||
sock.tls_config = tls_config
|
||||
|
||||
sock.dial(rpc_address)
|
||||
|
||||
async def _rpc_call(
|
||||
method: str,
|
||||
params: dict = {},
|
||||
uid: Optional[str] = None
|
||||
):
|
||||
req = SkynetRPCRequest()
|
||||
req.uid = uid if uid else unique_id
|
||||
req.method = method
|
||||
req.params.update(params)
|
||||
|
||||
if security:
|
||||
req.auth.cert = cert_name
|
||||
req.auth.sig = sign_protobuf_msg(req, tls_key)
|
||||
|
||||
ctx = sock.new_context()
|
||||
await ctx.asend(req.SerializeToString())
|
||||
|
||||
resp = SkynetRPCResponse()
|
||||
resp.ParseFromString(await ctx.arecv())
|
||||
ctx.close()
|
||||
|
||||
if security:
|
||||
verify_protobuf_msg(resp, skynet_cert)
|
||||
|
||||
return resp
|
||||
|
||||
yield _rpc_call
|
||||
|
||||
sesh = SessionClient(
|
||||
rpc_address,
|
||||
unique_id,
|
||||
cert_name=cert_name,
|
||||
key_name=key_name
|
||||
)
|
||||
logging.debug(f'opening skynet rpc...')
|
||||
sesh.connect()
|
||||
yield sesh
|
||||
sesh.disconnect()
|
||||
|
||||
def validate_user_config_request(req: str):
|
||||
params = req.split(' ')
|
||||
|
|
|
@ -6,8 +6,6 @@ import logging
|
|||
|
||||
from datetime import datetime
|
||||
|
||||
import pynng
|
||||
|
||||
from PIL import Image
|
||||
from trio_asyncio import aio_as_trio
|
||||
|
||||
|
@ -16,6 +14,7 @@ from telebot.types import (
|
|||
)
|
||||
from telebot.async_telebot import AsyncTeleBot
|
||||
|
||||
from ..db import open_database_connection
|
||||
from ..constants import *
|
||||
|
||||
from . import *
|
||||
|
@ -56,228 +55,274 @@ def prepare_metainfo_caption(tguser, meta: dict) -> str:
|
|||
|
||||
|
||||
async def run_skynet_telegram(
|
||||
name: str,
|
||||
tg_token: str,
|
||||
key_name: str = 'telegram-frontend',
|
||||
cert_name: str = 'whitelist/telegram-frontend',
|
||||
rpc_address: str = DEFAULT_RPC_ADDR
|
||||
key_name: str = 'telegram-frontend.key',
|
||||
cert_name: str = 'whitelist/telegram-frontend.cert',
|
||||
rpc_address: str = DEFAULT_RPC_ADDR,
|
||||
db_host: str = 'localhost:5432',
|
||||
db_user: str = 'skynet',
|
||||
db_pass: str = 'password'
|
||||
):
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
bot = AsyncTeleBot(tg_token)
|
||||
logging.info(f'tg_token: {tg_token}')
|
||||
|
||||
async with open_skynet_rpc(
|
||||
'skynet-telegram-0',
|
||||
rpc_address=rpc_address,
|
||||
security=True,
|
||||
cert_name=cert_name,
|
||||
key_name=key_name
|
||||
) as rpc_call:
|
||||
async with open_database_connection(
|
||||
db_user, db_pass, db_host
|
||||
) as db_call:
|
||||
with open_skynet_rpc(
|
||||
f'skynet-telegram-{name}',
|
||||
rpc_address=rpc_address,
|
||||
cert_name=cert_name,
|
||||
key_name=key_name
|
||||
) as session:
|
||||
|
||||
async def _rpc_call(
|
||||
uid: int,
|
||||
method: str,
|
||||
params: dict = {}
|
||||
):
|
||||
return await rpc_call(
|
||||
method, params, uid=f'{PREFIX}+{uid}')
|
||||
@bot.message_handler(commands=['help'])
|
||||
async def send_help(message):
|
||||
splt_msg = message.text.split(' ')
|
||||
|
||||
@bot.message_handler(commands=['help'])
|
||||
async def send_help(message):
|
||||
splt_msg = message.text.split(' ')
|
||||
|
||||
if len(splt_msg) == 1:
|
||||
await bot.reply_to(message, HELP_TEXT)
|
||||
|
||||
else:
|
||||
param = splt_msg[1]
|
||||
if param in HELP_TOPICS:
|
||||
await bot.reply_to(message, HELP_TOPICS[param])
|
||||
if len(splt_msg) == 1:
|
||||
await bot.reply_to(message, HELP_TEXT)
|
||||
|
||||
else:
|
||||
await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
|
||||
param = splt_msg[1]
|
||||
if param in HELP_TOPICS:
|
||||
await bot.reply_to(message, HELP_TOPICS[param])
|
||||
|
||||
@bot.message_handler(commands=['cool'])
|
||||
async def send_cool_words(message):
|
||||
await bot.reply_to(message, '\n'.join(COOL_WORDS))
|
||||
else:
|
||||
await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
|
||||
|
||||
@bot.message_handler(commands=['txt2img'])
|
||||
async def send_txt2img(message):
|
||||
chat = message.chat
|
||||
@bot.message_handler(commands=['cool'])
|
||||
async def send_cool_words(message):
|
||||
await bot.reply_to(message, '\n'.join(COOL_WORDS))
|
||||
|
||||
prompt = ' '.join(message.text.split(' ')[1:])
|
||||
@bot.message_handler(commands=['txt2img'])
|
||||
async def send_txt2img(message):
|
||||
chat = message.chat
|
||||
reply_id = None
|
||||
if chat.type == 'group' and chat.id == GROUP_ID:
|
||||
reply_id = message.message_id
|
||||
|
||||
if len(prompt) == 0:
|
||||
await bot.reply_to(message, 'Empty text prompt ignored.')
|
||||
return
|
||||
user_id = f'tg+{message.from_user.id}'
|
||||
|
||||
logging.info(f'mid: {message.id}')
|
||||
resp = await _rpc_call(
|
||||
message.from_user.id,
|
||||
'txt2img',
|
||||
{'prompt': prompt}
|
||||
)
|
||||
logging.info(f'resp to {message.id} arrived')
|
||||
prompt = ' '.join(message.text.split(' ')[1:])
|
||||
|
||||
resp_txt = ''
|
||||
result = MessageToDict(resp.result)
|
||||
if 'error' in resp.result:
|
||||
resp_txt = resp.result['message']
|
||||
if len(prompt) == 0:
|
||||
await bot.reply_to(message, 'Empty text prompt ignored.')
|
||||
return
|
||||
|
||||
else:
|
||||
logging.info(result['id'])
|
||||
img_raw = zlib.decompress(bytes.fromhex(result['img']))
|
||||
logging.info(f'got image of size: {len(img_raw)}')
|
||||
img = Image.open(io.BytesIO(img_raw))
|
||||
logging.info(f'mid: {message.id}')
|
||||
user = await db_call('get_or_create_user', user_id)
|
||||
user_config = {**(await db_call('get_user_config', user))}
|
||||
del user_config['id']
|
||||
|
||||
await bot.send_photo(
|
||||
GROUP_ID,
|
||||
caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']),
|
||||
photo=img,
|
||||
reply_markup=build_redo_menu()
|
||||
resp = await session.rpc(
|
||||
'dgpu_call', {
|
||||
'method': 'diffuse',
|
||||
'params': {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
}
|
||||
},
|
||||
timeout=60
|
||||
)
|
||||
return
|
||||
logging.info(f'resp to {message.id} arrived')
|
||||
|
||||
await bot.reply_to(message, resp_txt)
|
||||
resp_txt = ''
|
||||
result = MessageToDict(resp.result)
|
||||
if 'error' in resp.result:
|
||||
resp_txt = resp.result['message']
|
||||
await bot.reply_to(message, resp_txt)
|
||||
|
||||
@bot.message_handler(func=lambda message: True, content_types=['photo'])
|
||||
async def send_img2img(message):
|
||||
chat = message.chat
|
||||
else:
|
||||
logging.info(result['id'])
|
||||
img_raw = resp.bin
|
||||
logging.info(f'got image of size: {len(img_raw)}')
|
||||
img = Image.open(io.BytesIO(img_raw))
|
||||
|
||||
if not message.caption.startswith('/img2img'):
|
||||
return
|
||||
await bot.send_photo(
|
||||
GROUP_ID,
|
||||
caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']),
|
||||
photo=img,
|
||||
reply_to_message_id=reply_id,
|
||||
reply_markup=build_redo_menu()
|
||||
)
|
||||
return
|
||||
|
||||
prompt = ' '.join(message.caption.split(' ')[1:])
|
||||
|
||||
if len(prompt) == 0:
|
||||
await bot.reply_to(message, 'Empty text prompt ignored.')
|
||||
return
|
||||
@bot.message_handler(func=lambda message: True, content_types=['photo'])
|
||||
async def send_img2img(message):
|
||||
chat = message.chat
|
||||
reply_id = None
|
||||
if chat.type == 'group' and chat.id == GROUP_ID:
|
||||
reply_id = message.message_id
|
||||
|
||||
file_id = message.photo[-1].file_id
|
||||
file_path = (await bot.get_file(file_id)).file_path
|
||||
file_raw = await bot.download_file(file_path)
|
||||
img = zlib.compress(file_raw)
|
||||
user_id = f'tg+{message.from_user.id}'
|
||||
|
||||
logging.info(f'mid: {message.id}')
|
||||
resp = await _rpc_call(
|
||||
message.from_user.id,
|
||||
'img2img',
|
||||
{'prompt': prompt, 'img': img.hex()}
|
||||
)
|
||||
logging.info(f'resp to {message.id} arrived')
|
||||
if not message.caption.startswith('/img2img'):
|
||||
await bot.reply_to(
|
||||
message,
|
||||
'For image to image you need to add /img2img to the beggining of your caption'
|
||||
)
|
||||
return
|
||||
|
||||
resp_txt = ''
|
||||
result = MessageToDict(resp.result)
|
||||
if 'error' in resp.result:
|
||||
resp_txt = resp.result['message']
|
||||
prompt = ' '.join(message.caption.split(' ')[1:])
|
||||
|
||||
else:
|
||||
logging.info(result['id'])
|
||||
img_raw = zlib.decompress(bytes.fromhex(result['img']))
|
||||
logging.info(f'got image of size: {len(img_raw)}')
|
||||
img = Image.open(io.BytesIO(img_raw))
|
||||
if len(prompt) == 0:
|
||||
await bot.reply_to(message, 'Empty text prompt ignored.')
|
||||
return
|
||||
|
||||
await bot.send_media_group(
|
||||
GROUP_ID,
|
||||
media=[
|
||||
InputMediaPhoto(file_id),
|
||||
InputMediaPhoto(
|
||||
img,
|
||||
caption=prepare_metainfo_caption(message.from_user, result['meta']['meta'])
|
||||
)
|
||||
]
|
||||
file_id = message.photo[-1].file_id
|
||||
file_path = (await bot.get_file(file_id)).file_path
|
||||
file_raw = await bot.download_file(file_path)
|
||||
|
||||
logging.info(f'mid: {message.id}')
|
||||
|
||||
user = await db_call('get_or_create_user', user_id)
|
||||
user_config = {**(await db_call('get_user_config', user))}
|
||||
del user_config['id']
|
||||
|
||||
resp = await session.rpc(
|
||||
'dgpu_call', {
|
||||
'method': 'diffuse',
|
||||
'params': {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
}
|
||||
},
|
||||
binext=file_raw,
|
||||
timeout=60
|
||||
)
|
||||
return
|
||||
logging.info(f'resp to {message.id} arrived')
|
||||
|
||||
await bot.reply_to(message, resp_txt)
|
||||
resp_txt = ''
|
||||
result = MessageToDict(resp.result)
|
||||
if 'error' in resp.result:
|
||||
resp_txt = resp.result['message']
|
||||
await bot.reply_to(message, resp_txt)
|
||||
|
||||
@bot.message_handler(commands=['img2img'])
|
||||
async def redo_txt2img(message):
|
||||
await bot.reply_to(
|
||||
message,
|
||||
'seems you tried to do an img2img command without sending image'
|
||||
)
|
||||
else:
|
||||
logging.info(result['id'])
|
||||
img_raw = resp.bin
|
||||
logging.info(f'got image of size: {len(img_raw)}')
|
||||
img = Image.open(io.BytesIO(img_raw))
|
||||
|
||||
async def _redo(message):
|
||||
resp = await _rpc_call(message.from_user.id, 'redo')
|
||||
await bot.send_media_group(
|
||||
GROUP_ID,
|
||||
media=[
|
||||
InputMediaPhoto(file_id),
|
||||
InputMediaPhoto(
|
||||
img,
|
||||
caption=prepare_metainfo_caption(message.from_user, result['meta']['meta'])
|
||||
)
|
||||
],
|
||||
reply_to_message_id=reply_id
|
||||
)
|
||||
return
|
||||
|
||||
resp_txt = ''
|
||||
result = MessageToDict(resp.result)
|
||||
if 'error' in resp.result:
|
||||
resp_txt = resp.result['message']
|
||||
|
||||
else:
|
||||
logging.info(result['id'])
|
||||
img_raw = zlib.decompress(bytes.fromhex(result['img']))
|
||||
logging.info(f'got image of size: {len(img_raw)}')
|
||||
img = Image.open(io.BytesIO(img_raw))
|
||||
|
||||
await bot.send_photo(
|
||||
GROUP_ID,
|
||||
caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']),
|
||||
photo=img,
|
||||
reply_markup=build_redo_menu()
|
||||
@bot.message_handler(commands=['img2img'])
|
||||
async def img2img_missing_image(message):
|
||||
await bot.reply_to(
|
||||
message,
|
||||
'seems you tried to do an img2img command without sending image'
|
||||
)
|
||||
return
|
||||
|
||||
await bot.reply_to(message, resp_txt)
|
||||
@bot.message_handler(commands=['redo'])
|
||||
async def redo(message):
|
||||
chat = message.chat
|
||||
reply_id = None
|
||||
if chat.type == 'group' and chat.id == GROUP_ID:
|
||||
reply_id = message.message_id
|
||||
|
||||
@bot.message_handler(commands=['redo'])
|
||||
async def redo_txt2img(message):
|
||||
await _redo(message)
|
||||
user_config = {**(await db_call('get_user_config', user))}
|
||||
del user_config['id']
|
||||
prompt = await db_call('get_last_prompt_of', user)
|
||||
|
||||
@bot.message_handler(commands=['config'])
|
||||
async def set_config(message):
|
||||
rpc_params = {}
|
||||
try:
|
||||
attr, val, reply_txt = validate_user_config_request(
|
||||
message.text)
|
||||
resp = await session.rpc(
|
||||
'dgpu_call', {
|
||||
'method': 'diffuse',
|
||||
'params': {
|
||||
'prompt': prompt,
|
||||
**user_config
|
||||
}
|
||||
},
|
||||
timeout=60
|
||||
)
|
||||
logging.info(f'resp to {message.id} arrived')
|
||||
|
||||
resp = await _rpc_call(
|
||||
message.from_user.id,
|
||||
'config', {'attr': attr, 'val': val})
|
||||
resp_txt = ''
|
||||
result = MessageToDict(resp.result)
|
||||
if 'error' in resp.result:
|
||||
resp_txt = resp.result['message']
|
||||
await bot.reply_to(message, resp_txt)
|
||||
|
||||
except BaseException as e:
|
||||
reply_txt = str(e)
|
||||
else:
|
||||
logging.info(result['id'])
|
||||
img_raw = resp.bin
|
||||
logging.info(f'got image of size: {len(img_raw)}')
|
||||
img = Image.open(io.BytesIO(img_raw))
|
||||
|
||||
finally:
|
||||
await bot.reply_to(message, reply_txt)
|
||||
await bot.send_photo(
|
||||
GROUP_ID,
|
||||
caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']),
|
||||
photo=img,
|
||||
reply_to_message_id=reply_id
|
||||
)
|
||||
return
|
||||
|
||||
@bot.message_handler(commands=['stats'])
|
||||
async def user_stats(message):
|
||||
resp = await _rpc_call(
|
||||
message.from_user.id,
|
||||
'stats',
|
||||
{}
|
||||
)
|
||||
stats = resp.result
|
||||
@bot.message_handler(commands=['config'])
|
||||
async def set_config(message):
|
||||
rpc_params = {}
|
||||
try:
|
||||
attr, val, reply_txt = validate_user_config_request(
|
||||
message.text)
|
||||
|
||||
stats_str = f'generated: {stats["generated"]}\n'
|
||||
stats_str += f'joined: {stats["joined"]}\n'
|
||||
stats_str += f'role: {stats["role"]}\n'
|
||||
logging.info(f'user config update: {attr} to {val}')
|
||||
await db_call('update_user_config',
|
||||
user, req.params['attr'], req.params['val'])
|
||||
logging.info('done')
|
||||
|
||||
await bot.reply_to(
|
||||
message, stats_str)
|
||||
except BaseException as e:
|
||||
reply_txt = str(e)
|
||||
|
||||
@bot.message_handler(commands=['donate'])
|
||||
async def donation_info(message):
|
||||
await bot.reply_to(
|
||||
message, DONATION_INFO)
|
||||
finally:
|
||||
await bot.reply_to(message, reply_txt)
|
||||
|
||||
@bot.message_handler(commands=['say'])
|
||||
async def say(message):
|
||||
chat = message.chat
|
||||
user = message.from_user
|
||||
@bot.message_handler(commands=['stats'])
|
||||
async def user_stats(message):
|
||||
|
||||
if (chat.type == 'group') or (user.id != 383385940):
|
||||
return
|
||||
generated, joined, role = await db_call('get_user_stats', user)
|
||||
|
||||
await bot.send_message(GROUP_ID, message.text[4:])
|
||||
stats_str = f'generated: {generated}\n'
|
||||
stats_str += f'joined: {joined}\n'
|
||||
stats_str += f'role: {role}\n'
|
||||
|
||||
await bot.reply_to(
|
||||
message, stats_str)
|
||||
|
||||
@bot.message_handler(commands=['donate'])
|
||||
async def donation_info(message):
|
||||
await bot.reply_to(
|
||||
message, DONATION_INFO)
|
||||
|
||||
@bot.message_handler(commands=['say'])
|
||||
async def say(message):
|
||||
chat = message.chat
|
||||
user = message.from_user
|
||||
|
||||
if (chat.type == 'group') or (user.id != 383385940):
|
||||
return
|
||||
|
||||
await bot.send_message(GROUP_ID, message.text[4:])
|
||||
|
||||
|
||||
@bot.message_handler(func=lambda message: True)
|
||||
async def echo_message(message):
|
||||
if message.text[0] == '/':
|
||||
await bot.reply_to(message, UNKNOWN_CMD_TEXT)
|
||||
@bot.message_handler(func=lambda message: True)
|
||||
async def echo_message(message):
|
||||
if message.text[0] == '/':
|
||||
await bot.reply_to(message, UNKNOWN_CMD_TEXT)
|
||||
|
||||
@bot.callback_query_handler(func=lambda call: True)
|
||||
async def callback_query(call):
|
||||
|
@ -289,4 +334,4 @@ async def run_skynet_telegram(
|
|||
await _redo(call)
|
||||
|
||||
|
||||
await aio_as_trio(bot.infinity_polling())
|
||||
await aio_as_trio(bot.infinity_polling)()
|
||||
|
|
|
@ -0,0 +1,341 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import zlib
|
||||
import socket
|
||||
|
||||
from typing import Callable, Awaitable, Optional
|
||||
from pathlib import Path
|
||||
from contextlib import asynccontextmanager as acm
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
import trio
|
||||
import pynng
|
||||
|
||||
from pynng import TLSConfig, Context
|
||||
|
||||
from .protobuf import *
|
||||
from .constants import *
|
||||
|
||||
|
||||
def get_random_port():
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
s.bind(('', 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def load_certs(
|
||||
certs_dir: str,
|
||||
cert_name: str,
|
||||
key_name: str
|
||||
):
|
||||
certs_dir = Path(certs_dir).resolve()
|
||||
tls_key_data = (certs_dir / key_name).read_bytes()
|
||||
tls_key = serialization.load_pem_private_key(
|
||||
tls_key_data,
|
||||
password=None
|
||||
)
|
||||
|
||||
tls_cert_data = (certs_dir / cert_name).read_bytes()
|
||||
tls_cert = x509.load_pem_x509_certificate(
|
||||
tls_cert_data
|
||||
)
|
||||
|
||||
tls_whitelist = {}
|
||||
for cert_path in (*(certs_dir / 'whitelist').glob('*.cert'), certs_dir / 'brain.cert'):
|
||||
tls_whitelist[cert_path.stem] = x509.load_pem_x509_certificate(
|
||||
cert_path.read_bytes()
|
||||
)
|
||||
|
||||
return (
|
||||
SessionTLSConfig(
|
||||
TLSConfig.MODE_SERVER,
|
||||
own_key_string=tls_key_data,
|
||||
own_cert_string=tls_cert_data
|
||||
),
|
||||
|
||||
tls_whitelist
|
||||
)
|
||||
|
||||
|
||||
def load_certs_client(
|
||||
certs_dir: str,
|
||||
cert_name: str,
|
||||
key_name: str,
|
||||
ca_name: Optional[str] = None
|
||||
):
|
||||
certs_dir = Path(certs_dir).resolve()
|
||||
if not ca_name:
|
||||
ca_name = 'brain.cert'
|
||||
|
||||
ca_cert_data = (certs_dir / ca_name).read_bytes()
|
||||
|
||||
tls_key_data = (certs_dir / key_name).read_bytes()
|
||||
|
||||
|
||||
tls_cert_data = (certs_dir / cert_name).read_bytes()
|
||||
|
||||
|
||||
tls_whitelist = {}
|
||||
for cert_path in (*(certs_dir / 'whitelist').glob('*.cert'), certs_dir / 'brain.cert'):
|
||||
tls_whitelist[cert_path.stem] = x509.load_pem_x509_certificate(
|
||||
cert_path.read_bytes()
|
||||
)
|
||||
|
||||
return (
|
||||
SessionTLSConfig(
|
||||
TLSConfig.MODE_CLIENT,
|
||||
own_key_string=tls_key_data,
|
||||
own_cert_string=tls_cert_data,
|
||||
ca_string=ca_cert_data
|
||||
),
|
||||
|
||||
tls_whitelist
|
||||
)
|
||||
|
||||
|
||||
class SessionError(BaseException):
|
||||
...
|
||||
|
||||
|
||||
class SessionTLSConfig(TLSConfig):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mode,
|
||||
server_name=None,
|
||||
ca_string=None,
|
||||
own_key_string=None,
|
||||
own_cert_string=None,
|
||||
auth_mode=None,
|
||||
ca_files=None,
|
||||
cert_key_file=None,
|
||||
passwd=None
|
||||
):
|
||||
super().__init__(
|
||||
mode,
|
||||
server_name=server_name,
|
||||
ca_string=ca_string,
|
||||
own_key_string=own_key_string,
|
||||
own_cert_string=own_cert_string,
|
||||
auth_mode=auth_mode,
|
||||
ca_files=ca_files,
|
||||
cert_key_file=cert_key_file,
|
||||
passwd=passwd
|
||||
)
|
||||
|
||||
if ca_string:
|
||||
self.ca_cert = x509.load_pem_x509_certificate(ca_string)
|
||||
|
||||
self.cert = x509.load_pem_x509_certificate(own_cert_string)
|
||||
self.key = serialization.load_pem_private_key(
|
||||
own_key_string,
|
||||
password=passwd
|
||||
)
|
||||
|
||||
|
||||
class SessionServer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
addr: str,
|
||||
msg_handler: Callable[
|
||||
[SkynetRPCRequest, Context], Awaitable[SkynetRPCResponse]
|
||||
],
|
||||
cert_name: Optional[str] = None,
|
||||
key_name: Optional[str] = None,
|
||||
cert_dir: str = DEFAULT_CERTS_DIR,
|
||||
recv_max_size = 0
|
||||
):
|
||||
self.addr = addr
|
||||
self.msg_handler = msg_handler
|
||||
|
||||
self.cert_name = cert_name
|
||||
self.tls_config = None
|
||||
self.tls_whitelist = None
|
||||
if cert_name and key_name:
|
||||
self.cert_name = cert_name
|
||||
self.tls_config, self.tls_whitelist = load_certs(
|
||||
cert_dir, cert_name, key_name)
|
||||
|
||||
self.addr = 'tls+' + self.addr
|
||||
|
||||
self.recv_max_size = recv_max_size
|
||||
|
||||
async def _handle_msg(self, req: SkynetRPCRequest, ctx: Context):
|
||||
resp = await self.msg_handler(req, ctx)
|
||||
|
||||
if self.tls_config:
|
||||
resp.auth.cert = 'skynet'
|
||||
resp.auth.sig = sign_protobuf_msg(
|
||||
resp, self.tls_config.key)
|
||||
|
||||
raw_msg = zlib.compress(resp.SerializeToString())
|
||||
|
||||
await ctx.asend(raw_msg)
|
||||
|
||||
ctx.close()
|
||||
|
||||
async def _listener (self, sock):
|
||||
async with trio.open_nursery() as n:
|
||||
while True:
|
||||
ctx = sock.new_context()
|
||||
|
||||
raw_msg = await ctx.arecv()
|
||||
raw_size = len(raw_msg)
|
||||
logging.debug(f'rpc server new msg {raw_size} bytes')
|
||||
|
||||
try:
|
||||
msg = zlib.decompress(raw_msg)
|
||||
msg_size = len(msg)
|
||||
|
||||
except zlib.error:
|
||||
logging.warning(f'Zlib decompress error, dropping msg of size {len(raw_msg)}')
|
||||
continue
|
||||
|
||||
logging.debug(f'msg after decompress {msg_size} bytes, +{msg_size - raw_size} bytes')
|
||||
|
||||
req = SkynetRPCRequest()
|
||||
try:
|
||||
req.ParseFromString(msg)
|
||||
|
||||
except google.protobuf.message.DecodeError:
|
||||
logging.warning(f'Dropping malfomed msg of size {len(msg)}')
|
||||
continue
|
||||
|
||||
logging.debug(f'msg method: {req.method}')
|
||||
|
||||
if self.tls_config:
|
||||
if req.auth.cert not in self.tls_whitelist:
|
||||
logging.warning(
|
||||
f'{req.auth.cert} not in tls whitelist')
|
||||
continue
|
||||
|
||||
try:
|
||||
verify_protobuf_msg(req, self.tls_whitelist[req.auth.cert])
|
||||
|
||||
except ValueError:
|
||||
logging.warning(
|
||||
f'{req.cert} sent an unauthenticated msg')
|
||||
continue
|
||||
|
||||
n.start_soon(self._handle_msg, req, ctx)
|
||||
|
||||
@acm
|
||||
async def open(self):
|
||||
with pynng.Rep0(
|
||||
recv_max_size=self.recv_max_size
|
||||
) as sock:
|
||||
|
||||
if self.tls_config:
|
||||
sock.tls_config = self.tls_config
|
||||
|
||||
sock.listen(self.addr)
|
||||
|
||||
logging.debug(f'server socket listening at {self.addr}')
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
n.start_soon(self._listener, sock)
|
||||
|
||||
try:
|
||||
yield self
|
||||
|
||||
finally:
|
||||
n.cancel_scope.cancel()
|
||||
|
||||
logging.debug('server socket is off.')
|
||||
|
||||
|
||||
class SessionClient:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connect_addr: str,
|
||||
uid: str,
|
||||
cert_name: Optional[str] = None,
|
||||
key_name: Optional[str] = None,
|
||||
ca_name: Optional[str] = None,
|
||||
cert_dir: str = DEFAULT_CERTS_DIR,
|
||||
recv_max_size = 0
|
||||
):
|
||||
self.uid = uid
|
||||
self.connect_addr = connect_addr
|
||||
|
||||
self.cert_name = None
|
||||
self.tls_config = None
|
||||
self.tls_whitelist = None
|
||||
self.tls_cert = None
|
||||
self.tls_key = None
|
||||
if cert_name and key_name:
|
||||
self.cert_name = Path(cert_name).stem
|
||||
self.tls_config, self.tls_whitelist = load_certs_client(
|
||||
cert_dir, cert_name, key_name, ca_name=ca_name)
|
||||
|
||||
if not self.connect_addr.startswith('tls'):
|
||||
self.connect_addr = 'tls+' + self.connect_addr
|
||||
|
||||
self.recv_max_size = recv_max_size
|
||||
|
||||
self._connected = False
|
||||
self._sock = None
|
||||
|
||||
def connect(self):
|
||||
self._sock = pynng.Req0(
|
||||
recv_max_size=0,
|
||||
name=self.uid
|
||||
)
|
||||
|
||||
if self.tls_config:
|
||||
self._sock.tls_config = self.tls_config
|
||||
|
||||
logging.debug(f'client is dialing {self.connect_addr}...')
|
||||
self._sock.dial(self.connect_addr, block=True)
|
||||
self._connected = True
|
||||
logging.debug(f'client is connected to {self.connect_addr}')
|
||||
|
||||
def disconnect(self):
|
||||
self._sock.close()
|
||||
self._connected = False
|
||||
logging.debug(f'client disconnected.')
|
||||
|
||||
async def rpc(
|
||||
self,
|
||||
method: str,
|
||||
params: dict = {},
|
||||
binext: Optional[bytes] = None,
|
||||
timeout: float = 2.
|
||||
):
|
||||
if not self._connected:
|
||||
raise SessionError('tried to use rpc without connecting')
|
||||
|
||||
req = SkynetRPCRequest()
|
||||
req.uid = self.uid
|
||||
req.method = method
|
||||
req.params.update(params)
|
||||
if binext:
|
||||
logging.debug('added binary extension')
|
||||
req.bin = binext
|
||||
|
||||
if self.tls_config:
|
||||
req.auth.cert = self.cert_name
|
||||
req.auth.sig = sign_protobuf_msg(req, self.tls_config.key)
|
||||
|
||||
with trio.fail_after(timeout):
|
||||
ctx = self._sock.new_context()
|
||||
raw_req = zlib.compress(req.SerializeToString())
|
||||
logging.debug(f'rpc client sending new msg {method} of size {len(raw_req)}')
|
||||
await ctx.asend(raw_req)
|
||||
logging.debug('sent, awaiting response...')
|
||||
raw_resp = await ctx.arecv()
|
||||
logging.debug(f'rpc client got response of size {len(raw_resp)}')
|
||||
raw_resp = zlib.decompress(raw_resp)
|
||||
|
||||
resp = SkynetRPCResponse()
|
||||
resp.ParseFromString(raw_resp)
|
||||
ctx.close()
|
||||
|
||||
if self.tls_config:
|
||||
verify_protobuf_msg(resp, self.tls_config.ca_cert)
|
||||
|
||||
return resp
|
|
@ -1,29 +1,4 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
|
||||
from .auth import *
|
||||
from .skynet_pb2 import *
|
||||
|
||||
|
||||
class Struct:
|
||||
|
||||
def to_dict(self):
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiffusionParameters(Struct):
|
||||
algo: str
|
||||
prompt: str
|
||||
step: int
|
||||
width: int
|
||||
height: int
|
||||
guidance: float
|
||||
strength: float
|
||||
seed: Optional[int]
|
||||
image: bool # if true indicates a bytestream is next msg
|
||||
upscaler: Optional[str]
|
||||
|
|
|
@ -7,7 +7,8 @@ from hashlib import sha256
|
|||
from collections import OrderedDict
|
||||
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
from OpenSSL.crypto import PKey, X509, verify, sign
|
||||
from cryptography.hazmat.primitives import serialization, hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
|
||||
from .skynet_pb2 import *
|
||||
|
||||
|
@ -46,20 +47,23 @@ def serialize_msg_deterministic(msg):
|
|||
if field_descriptor.message_type.name == 'Struct':
|
||||
hash_dict(MessageToDict(getattr(msg, field_name)))
|
||||
|
||||
deterministic_msg = shasum.hexdigest()
|
||||
deterministic_msg = shasum.digest()
|
||||
|
||||
return deterministic_msg
|
||||
|
||||
|
||||
def sign_protobuf_msg(msg, key: PKey):
|
||||
return sign(
|
||||
key, serialize_msg_deterministic(msg), 'sha256').hex()
|
||||
def sign_protobuf_msg(msg, key):
|
||||
return key.sign(
|
||||
serialize_msg_deterministic(msg),
|
||||
padding.PKCS1v15(),
|
||||
hashes.SHA256()
|
||||
).hex()
|
||||
|
||||
|
||||
def verify_protobuf_msg(msg, cert: X509):
|
||||
return verify(
|
||||
cert,
|
||||
def verify_protobuf_msg(msg, cert):
|
||||
return cert.public_key().verify(
|
||||
bytes.fromhex(msg.auth.sig),
|
||||
serialize_msg_deterministic(msg),
|
||||
'sha256'
|
||||
padding.PKCS1v15(),
|
||||
hashes.SHA256()
|
||||
)
|
||||
|
|
|
@ -13,18 +13,12 @@ message SkynetRPCRequest {
|
|||
string uid = 1;
|
||||
string method = 2;
|
||||
google.protobuf.Struct params = 3;
|
||||
optional Auth auth = 4;
|
||||
optional bytes bin = 4;
|
||||
optional Auth auth = 5;
|
||||
}
|
||||
|
||||
message SkynetRPCResponse {
|
||||
google.protobuf.Struct result = 1;
|
||||
optional Auth auth = 2;
|
||||
}
|
||||
|
||||
message DGPUBusMessage {
|
||||
string rid = 1;
|
||||
string nid = 2;
|
||||
string method = 3;
|
||||
google.protobuf.Struct params = 4;
|
||||
optional Auth auth = 5;
|
||||
optional bytes bin = 2;
|
||||
optional Auth auth = 3;
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
|
|||
from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cskynet.proto\x12\x06skynet\x1a\x1cgoogle/protobuf/struct.proto\"!\n\x04\x41uth\x12\x0c\n\x04\x63\x65rt\x18\x01 \x01(\t\x12\x0b\n\x03sig\x18\x02 \x01(\t\"\x82\x01\n\x10SkynetRPCRequest\x12\x0b\n\x03uid\x18\x01 \x01(\t\x12\x0e\n\x06method\x18\x02 \x01(\t\x12\'\n\x06params\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x04 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"f\n\x11SkynetRPCResponse\x12\'\n\x06result\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x02 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"\x8d\x01\n\x0e\x44GPUBusMessage\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x0b\n\x03nid\x18\x02 \x01(\t\x12\x0e\n\x06method\x18\x03 \x01(\t\x12\'\n\x06params\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x05 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_authb\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cskynet.proto\x12\x06skynet\x1a\x1cgoogle/protobuf/struct.proto\"!\n\x04\x41uth\x12\x0c\n\x04\x63\x65rt\x18\x01 \x01(\t\x12\x0b\n\x03sig\x18\x02 \x01(\t\"\x9c\x01\n\x10SkynetRPCRequest\x12\x0b\n\x03uid\x18\x01 \x01(\t\x12\x0e\n\x06method\x18\x02 \x01(\t\x12\'\n\x06params\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x10\n\x03\x62in\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x12\x1f\n\x04\x61uth\x18\x05 \x01(\x0b\x32\x0c.skynet.AuthH\x01\x88\x01\x01\x42\x06\n\x04_binB\x07\n\x05_auth\"\x80\x01\n\x11SkynetRPCResponse\x12\'\n\x06result\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x10\n\x03\x62in\x18\x02 \x01(\x0cH\x00\x88\x01\x01\x12\x1f\n\x04\x61uth\x18\x03 \x01(\x0b\x32\x0c.skynet.AuthH\x01\x88\x01\x01\x42\x06\n\x04_binB\x07\n\x05_authb\x06proto3')
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'skynet_pb2', globals())
|
||||
|
@ -24,9 +24,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|||
_AUTH._serialized_start=54
|
||||
_AUTH._serialized_end=87
|
||||
_SKYNETRPCREQUEST._serialized_start=90
|
||||
_SKYNETRPCREQUEST._serialized_end=220
|
||||
_SKYNETRPCRESPONSE._serialized_start=222
|
||||
_SKYNETRPCRESPONSE._serialized_end=324
|
||||
_DGPUBUSMESSAGE._serialized_start=327
|
||||
_DGPUBUSMESSAGE._serialized_end=468
|
||||
_SKYNETRPCREQUEST._serialized_end=246
|
||||
_SKYNETRPCRESPONSE._serialized_start=249
|
||||
_SKYNETRPCRESPONSE._serialized_end=377
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import time
|
||||
import random
|
||||
|
||||
from typing import Optional
|
||||
|
@ -21,6 +22,10 @@ from huggingface_hub import login
|
|||
from .constants import ALGOS
|
||||
|
||||
|
||||
def time_ms():
|
||||
return int(time.time() * 1000)
|
||||
|
||||
|
||||
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
|
||||
# return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||||
return Image.fromarray(img)
|
||||
|
@ -164,3 +169,13 @@ def upscale(
|
|||
|
||||
|
||||
image.save(output)
|
||||
|
||||
|
||||
def download_all_models(hf_token: str):
|
||||
assert torch.cuda.is_available()
|
||||
|
||||
login(token=hf_token)
|
||||
for model in ALGOS:
|
||||
print(f'DOWNLOADING {model.upper()}')
|
||||
pipeline_for(model)
|
||||
|
||||
|
|
|
@ -3,89 +3,30 @@
|
|||
import os
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
import string
|
||||
import logging
|
||||
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
import trio
|
||||
import pytest
|
||||
import psycopg2
|
||||
import trio_asyncio
|
||||
|
||||
from docker.types import Mount, DeviceRequest
|
||||
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
|
||||
|
||||
from skynet.constants import *
|
||||
from skynet.db import open_new_database
|
||||
from skynet.brain import run_skynet
|
||||
from skynet.network import get_random_port
|
||||
from skynet.constants import *
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def postgres_db(dockerctl):
|
||||
rpassword = ''.join(
|
||||
random.choice(string.ascii_lowercase)
|
||||
for i in range(12))
|
||||
password = ''.join(
|
||||
random.choice(string.ascii_lowercase)
|
||||
for i in range(12))
|
||||
|
||||
with dockerctl.run(
|
||||
'postgres',
|
||||
name='skynet-test-postgres',
|
||||
ports={'5432/tcp': None},
|
||||
environment={
|
||||
'POSTGRES_PASSWORD': rpassword
|
||||
}
|
||||
) as containers:
|
||||
container = containers[0]
|
||||
# ip = container.attrs['NetworkSettings']['IPAddress']
|
||||
port = container.ports['5432/tcp'][0]['HostPort']
|
||||
host = f'localhost:{port}'
|
||||
|
||||
for log in container.logs(stream=True):
|
||||
log = log.decode().rstrip()
|
||||
logging.info(log)
|
||||
if ('database system is ready to accept connections' in log or
|
||||
'database system is shut down' in log):
|
||||
break
|
||||
|
||||
# why print the system is ready to accept connections when its not
|
||||
# postgres? wtf
|
||||
time.sleep(1)
|
||||
logging.info('creating skynet db...')
|
||||
|
||||
conn = psycopg2.connect(
|
||||
user='postgres',
|
||||
password=rpassword,
|
||||
host='localhost',
|
||||
port=port
|
||||
)
|
||||
logging.info('connected...')
|
||||
conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
f'CREATE USER {DB_USER} WITH PASSWORD \'{password}\'')
|
||||
cursor.execute(
|
||||
f'CREATE DATABASE {DB_NAME}')
|
||||
cursor.execute(
|
||||
f'GRANT ALL PRIVILEGES ON DATABASE {DB_NAME} TO {DB_USER}')
|
||||
|
||||
conn.close()
|
||||
|
||||
logging.info('done.')
|
||||
yield container, password, host
|
||||
with open_new_database() as db_params:
|
||||
yield db_params
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def skynet_running(postgres_db):
|
||||
db_container, db_pass, db_host = postgres_db
|
||||
|
||||
async with run_skynet(
|
||||
db_pass=db_pass,
|
||||
db_host=db_host
|
||||
):
|
||||
async def skynet_running():
|
||||
async with run_skynet():
|
||||
yield
|
||||
|
||||
|
||||
|
@ -99,11 +40,13 @@ def dgpu_workers(request, dockerctl, skynet_running):
|
|||
|
||||
cmds = []
|
||||
for i in range(num_containers):
|
||||
dgpu_addr = f'tcp://127.0.0.1:{get_random_port()}'
|
||||
cmd = f'''
|
||||
pip install -e . && \
|
||||
skynet run dgpu \
|
||||
--algos=\'{json.dumps(initial_algos)}\' \
|
||||
--uid=dgpu-{i}
|
||||
--uid=dgpu-{i} \
|
||||
--dgpu={dgpu_addr}
|
||||
'''
|
||||
cmds.append(['bash', '-c', cmd])
|
||||
|
||||
|
@ -114,16 +57,15 @@ def dgpu_workers(request, dockerctl, skynet_running):
|
|||
name='skynet-test-runtime-cuda',
|
||||
commands=cmds,
|
||||
environment={
|
||||
'HF_TOKEN': os.environ['HF_TOKEN'],
|
||||
'HF_HOME': '/skynet/hf_home'
|
||||
},
|
||||
network='host',
|
||||
mounts=mounts,
|
||||
device_requests=devices,
|
||||
num=num_containers
|
||||
num=num_containers,
|
||||
) as containers:
|
||||
yield containers
|
||||
|
||||
#for i, container in enumerate(containers):
|
||||
# logging.info(f'container {i} logs:')
|
||||
# logging.info(container.logs().decode())
|
||||
for i, container in enumerate(containers):
|
||||
logging.info(f'container {i} logs:')
|
||||
logging.info(container.logs().decode())
|
||||
|
|
|
@ -12,29 +12,26 @@ from functools import partial
|
|||
|
||||
import trio
|
||||
import pytest
|
||||
import trio_asyncio
|
||||
|
||||
from PIL import Image
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
|
||||
from skynet.brain import SkynetDGPUComputeError
|
||||
from skynet.constants import *
|
||||
from skynet.network import get_random_port, SessionServer
|
||||
from skynet.protobuf import SkynetRPCResponse
|
||||
from skynet.frontend import open_skynet_rpc
|
||||
from skynet.constants import *
|
||||
|
||||
|
||||
async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
|
||||
async def wait_for_dgpus(session, amount: int, timeout: float = 30.0):
|
||||
gpu_ready = False
|
||||
start_time = time.time()
|
||||
current_time = time.time()
|
||||
while not gpu_ready and (current_time - start_time) < timeout:
|
||||
res = await rpc('dgpu_workers')
|
||||
if res.result['ok'] >= amount:
|
||||
break
|
||||
with trio.fail_after(timeout):
|
||||
while not gpu_ready:
|
||||
res = await session.rpc('dgpu_workers')
|
||||
if res.result['ok'] >= amount:
|
||||
break
|
||||
|
||||
await trio.sleep(1)
|
||||
current_time = time.time()
|
||||
|
||||
assert (current_time - start_time) < timeout
|
||||
await trio.sleep(1)
|
||||
|
||||
|
||||
_images = set()
|
||||
|
@ -48,34 +45,33 @@ async def check_request_img(
|
|||
):
|
||||
global _images
|
||||
|
||||
async with open_skynet_rpc(
|
||||
with open_skynet_rpc(
|
||||
uid,
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as rpc_call:
|
||||
res = await rpc_call(
|
||||
'txt2img', {
|
||||
'prompt': 'red old tractor in a sunny wheat field',
|
||||
'step': 28,
|
||||
'width': width, 'height': height,
|
||||
'guidance': 7.5,
|
||||
'seed': None,
|
||||
'algo': list(ALGOS.keys())[i],
|
||||
'upscaler': upscaler
|
||||
})
|
||||
cert_name='whitelist/testing.cert',
|
||||
key_name='testing.key'
|
||||
) as session:
|
||||
res = await session.rpc(
|
||||
'dgpu_call', {
|
||||
'method': 'diffuse',
|
||||
'params': {
|
||||
'prompt': 'red old tractor in a sunny wheat field',
|
||||
'step': 28,
|
||||
'width': width, 'height': height,
|
||||
'guidance': 7.5,
|
||||
'seed': None,
|
||||
'algo': list(ALGOS.keys())[i],
|
||||
'upscaler': upscaler
|
||||
}
|
||||
},
|
||||
timeout=60
|
||||
)
|
||||
|
||||
if 'error' in res.result:
|
||||
raise SkynetDGPUComputeError(MessageToDict(res.result))
|
||||
|
||||
if upscaler == 'x4':
|
||||
width *= 4
|
||||
height *= 4
|
||||
|
||||
img_raw = zlib.decompress(bytes.fromhex(res.result['img']))
|
||||
img_raw = res.bin
|
||||
img_sha = sha256(img_raw).hexdigest()
|
||||
img = Image.frombytes(
|
||||
'RGB', (width, height), img_raw)
|
||||
img = Image.open(io.BytesIO(img_raw))
|
||||
|
||||
if expect_unique and img_sha in _images:
|
||||
raise ValueError('Duplicated image sha: {img_sha}')
|
||||
|
@ -96,13 +92,12 @@ async def test_dgpu_worker_compute_error(dgpu_workers):
|
|||
then generate a smaller image to show gpu worker recovery
|
||||
'''
|
||||
|
||||
async with open_skynet_rpc(
|
||||
with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as test_rpc:
|
||||
await wait_for_dgpus(test_rpc, 1)
|
||||
cert_name='whitelist/testing.cert',
|
||||
key_name='testing.key'
|
||||
) as session:
|
||||
await wait_for_dgpus(session, 1)
|
||||
|
||||
with pytest.raises(SkynetDGPUComputeError) as e:
|
||||
await check_request_img(0, width=4096, height=4096)
|
||||
|
@ -112,20 +107,35 @@ async def test_dgpu_worker_compute_error(dgpu_workers):
|
|||
await check_request_img(0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
||||
async def test_dgpu_worker(dgpu_workers):
|
||||
'''Generate one image in a single dgpu worker
|
||||
'''
|
||||
|
||||
with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
cert_name='whitelist/testing.cert',
|
||||
key_name='testing.key'
|
||||
) as session:
|
||||
await wait_for_dgpus(session, 1)
|
||||
|
||||
await check_request_img(0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'dgpu_workers', [(1, ['midj', 'stable'])], indirect=True)
|
||||
async def test_dgpu_workers(dgpu_workers):
|
||||
async def test_dgpu_worker_two_models(dgpu_workers):
|
||||
'''Generate two images in a single dgpu worker using
|
||||
two different models.
|
||||
'''
|
||||
|
||||
async with open_skynet_rpc(
|
||||
with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as test_rpc:
|
||||
await wait_for_dgpus(test_rpc, 1)
|
||||
cert_name='whitelist/testing.cert',
|
||||
key_name='testing.key'
|
||||
) as session:
|
||||
await wait_for_dgpus(session, 1)
|
||||
|
||||
await check_request_img(0)
|
||||
await check_request_img(1)
|
||||
|
@ -138,14 +148,12 @@ async def test_dgpu_worker_upscale(dgpu_workers):
|
|||
two different models.
|
||||
'''
|
||||
|
||||
async with open_skynet_rpc(
|
||||
with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as test_rpc:
|
||||
await wait_for_dgpus(test_rpc, 1)
|
||||
logging.error('UPSCALE')
|
||||
cert_name='whitelist/testing.cert',
|
||||
key_name='testing.key'
|
||||
) as session:
|
||||
await wait_for_dgpus(session, 1)
|
||||
|
||||
img = await check_request_img(0, upscaler='x4')
|
||||
|
||||
|
@ -157,13 +165,12 @@ async def test_dgpu_worker_upscale(dgpu_workers):
|
|||
async def test_dgpu_workers_two(dgpu_workers):
|
||||
'''Generate two images in two separate dgpu workers
|
||||
'''
|
||||
async with open_skynet_rpc(
|
||||
with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as test_rpc:
|
||||
await wait_for_dgpus(test_rpc, 2)
|
||||
cert_name='whitelist/testing.cert',
|
||||
key_name='testing.key'
|
||||
) as session:
|
||||
await wait_for_dgpus(session, 2, timeout=60)
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
n.start_soon(check_request_img, 0)
|
||||
|
@ -175,13 +182,12 @@ async def test_dgpu_workers_two(dgpu_workers):
|
|||
async def test_dgpu_worker_algo_swap(dgpu_workers):
|
||||
'''Generate an image using a non default model
|
||||
'''
|
||||
async with open_skynet_rpc(
|
||||
with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as test_rpc:
|
||||
await wait_for_dgpus(test_rpc, 1)
|
||||
cert_name='whitelist/testing.cert',
|
||||
key_name='testing.key'
|
||||
) as session:
|
||||
await wait_for_dgpus(session, 1)
|
||||
await check_request_img(5)
|
||||
|
||||
|
||||
|
@ -191,33 +197,32 @@ async def test_dgpu_rotation_next_worker(dgpu_workers):
|
|||
'''Connect three dgpu workers, disconnect and check next_worker
|
||||
rotation happens correctly
|
||||
'''
|
||||
async with open_skynet_rpc(
|
||||
with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as test_rpc:
|
||||
await wait_for_dgpus(test_rpc, 3)
|
||||
cert_name='whitelist/testing.cert',
|
||||
key_name='testing.key'
|
||||
) as session:
|
||||
await wait_for_dgpus(session, 3)
|
||||
|
||||
res = await test_rpc('dgpu_next')
|
||||
res = await session.rpc('dgpu_next')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 0
|
||||
|
||||
await check_request_img(0)
|
||||
|
||||
res = await test_rpc('dgpu_next')
|
||||
res = await session.rpc('dgpu_next')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 1
|
||||
|
||||
await check_request_img(0)
|
||||
|
||||
res = await test_rpc('dgpu_next')
|
||||
res = await session.rpc('dgpu_next')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 2
|
||||
|
||||
await check_request_img(0)
|
||||
|
||||
res = await test_rpc('dgpu_next')
|
||||
res = await session.rpc('dgpu_next')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 0
|
||||
|
||||
|
@ -228,13 +233,12 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
|
|||
'''Connect three dgpu workers, disconnect the first one and check
|
||||
next_worker rotation happens correctly
|
||||
'''
|
||||
async with open_skynet_rpc(
|
||||
with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as test_rpc:
|
||||
await wait_for_dgpus(test_rpc, 3)
|
||||
cert_name='whitelist/testing.cert',
|
||||
key_name='testing.key'
|
||||
) as session:
|
||||
await wait_for_dgpus(session, 3)
|
||||
|
||||
await trio.sleep(3)
|
||||
|
||||
|
@ -245,7 +249,7 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
|
|||
|
||||
dgpu_workers[0].wait()
|
||||
|
||||
res = await test_rpc('dgpu_workers')
|
||||
res = await session.rpc('dgpu_workers')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 2
|
||||
|
||||
|
@ -258,26 +262,43 @@ async def test_dgpu_no_ack_node_disconnect(skynet_running):
|
|||
'''Mock a node that connects, gets a request but fails to
|
||||
acknowledge it, then check skynet correctly drops the node
|
||||
'''
|
||||
async with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as rpc_call:
|
||||
|
||||
res = await rpc_call('dgpu_online')
|
||||
assert 'ok' in res.result
|
||||
async def mock_rpc(req, ctx):
|
||||
resp = SkynetRPCResponse()
|
||||
resp.result.update({'error': 'can\'t do it mate'})
|
||||
return resp
|
||||
|
||||
await wait_for_dgpus(rpc_call, 1)
|
||||
dgpu_addr = f'tcp://127.0.0.1:{get_random_port()}'
|
||||
mock_server = SessionServer(
|
||||
dgpu_addr,
|
||||
mock_rpc,
|
||||
cert_name='whitelist/testing.cert',
|
||||
key_name='testing.key'
|
||||
)
|
||||
|
||||
with pytest.raises(SkynetDGPUComputeError) as e:
|
||||
await check_request_img(0)
|
||||
async with mock_server.open():
|
||||
with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
cert_name='whitelist/testing.cert',
|
||||
key_name='testing.key'
|
||||
) as session:
|
||||
|
||||
assert 'dgpu failed to acknowledge request' in str(e)
|
||||
res = await session.rpc('dgpu_online', {
|
||||
'dgpu_addr': dgpu_addr,
|
||||
'cert': 'whitelist/testing.cert'
|
||||
})
|
||||
assert 'ok' in res.result
|
||||
|
||||
res = await rpc_call('dgpu_workers')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 0
|
||||
await wait_for_dgpus(session, 1)
|
||||
|
||||
with pytest.raises(SkynetDGPUComputeError) as e:
|
||||
await check_request_img(0)
|
||||
|
||||
assert 'can\'t do it mate' in str(e.value)
|
||||
|
||||
res = await session.rpc('dgpu_workers')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -286,13 +307,12 @@ async def test_dgpu_timeout_while_processing(dgpu_workers):
|
|||
'''Stop node while processing request to cause timeout and
|
||||
then check skynet correctly drops the node.
|
||||
'''
|
||||
async with open_skynet_rpc(
|
||||
with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as test_rpc:
|
||||
await wait_for_dgpus(test_rpc, 1)
|
||||
cert_name='whitelist/testing.cert',
|
||||
key_name='testing.key'
|
||||
) as session:
|
||||
await wait_for_dgpus(session, 1)
|
||||
|
||||
async def check_request_img_raises():
|
||||
with pytest.raises(SkynetDGPUComputeError) as e:
|
||||
|
@ -308,72 +328,62 @@ async def test_dgpu_timeout_while_processing(dgpu_workers):
|
|||
assert ec == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
||||
async def test_dgpu_heartbeat(dgpu_workers):
|
||||
'''
|
||||
'''
|
||||
async with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as test_rpc:
|
||||
await wait_for_dgpus(test_rpc, 1)
|
||||
await trio.sleep(120)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
||||
async def test_dgpu_img2img(dgpu_workers):
|
||||
|
||||
async with open_skynet_rpc(
|
||||
'1',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as rpc_call:
|
||||
await wait_for_dgpus(rpc_call, 1)
|
||||
with open_skynet_rpc(
|
||||
'test-ctx',
|
||||
cert_name='whitelist/testing.cert',
|
||||
key_name='testing.key'
|
||||
) as session:
|
||||
await wait_for_dgpus(session, 1)
|
||||
|
||||
await trio.sleep(2)
|
||||
|
||||
res = await rpc_call(
|
||||
'txt2img', {
|
||||
'prompt': 'red old tractor in a sunny wheat field',
|
||||
'step': 28,
|
||||
'width': 512, 'height': 512,
|
||||
'guidance': 7.5,
|
||||
'seed': None,
|
||||
'algo': list(ALGOS.keys())[0],
|
||||
'upscaler': None
|
||||
})
|
||||
res = await session.rpc(
|
||||
'dgpu_call', {
|
||||
'method': 'diffuse',
|
||||
'params': {
|
||||
'prompt': 'red old tractor in a sunny wheat field',
|
||||
'step': 28,
|
||||
'width': 512, 'height': 512,
|
||||
'guidance': 7.5,
|
||||
'seed': None,
|
||||
'algo': list(ALGOS.keys())[0],
|
||||
'upscaler': None
|
||||
}
|
||||
},
|
||||
timeout=60
|
||||
)
|
||||
|
||||
if 'error' in res.result:
|
||||
raise SkynetDGPUComputeError(MessageToDict(res.result))
|
||||
|
||||
img_raw = res.result['img']
|
||||
img = zlib.decompress(bytes.fromhex(img_raw))
|
||||
logging.info(img[:10])
|
||||
img = Image.open(io.BytesIO(img))
|
||||
|
||||
img_raw = res.bin
|
||||
img = Image.open(io.BytesIO(img_raw))
|
||||
img.save('txt2img.png')
|
||||
|
||||
res = await rpc_call(
|
||||
'img2img', {
|
||||
'prompt': 'red sports car in a sunny wheat field',
|
||||
'step': 28,
|
||||
'img': img_raw,
|
||||
'guidance': 12,
|
||||
'seed': None,
|
||||
'algo': list(ALGOS.keys())[0],
|
||||
'upscaler': 'x4'
|
||||
})
|
||||
res = await session.rpc(
|
||||
'dgpu_call', {
|
||||
'method': 'diffuse',
|
||||
'params': {
|
||||
'prompt': 'red ferrari in a sunny wheat field',
|
||||
'step': 28,
|
||||
'guidance': 8,
|
||||
'strength': 0.7,
|
||||
'seed': None,
|
||||
'algo': list(ALGOS.keys())[0],
|
||||
'upscaler': 'x4'
|
||||
}
|
||||
},
|
||||
binext=img_raw,
|
||||
timeout=60
|
||||
)
|
||||
|
||||
if 'error' in res.result:
|
||||
raise SkynetDGPUComputeError(MessageToDict(res.result))
|
||||
|
||||
img_raw = res.result['img']
|
||||
img = zlib.decompress(bytes.fromhex(img_raw))
|
||||
logging.info(img[:10])
|
||||
img = Image.open(io.BytesIO(img))
|
||||
|
||||
img_raw = res.bin
|
||||
img = Image.open(io.BytesIO(img_raw))
|
||||
img.save('img2img.png')
|
||||
|
|
|
@ -9,6 +9,7 @@ import trio_asyncio
|
|||
|
||||
from skynet.brain import run_skynet
|
||||
from skynet.structs import *
|
||||
from skynet.network import SessionServer
|
||||
from skynet.frontend import open_skynet_rpc
|
||||
|
||||
|
||||
|
@ -18,53 +19,68 @@ async def test_skynet(skynet_running):
|
|||
|
||||
async def test_skynet_attempt_insecure(skynet_running):
|
||||
with pytest.raises(pynng.exceptions.NNGException) as e:
|
||||
async with open_skynet_rpc('bad-actor'):
|
||||
...
|
||||
|
||||
assert str(e.value) == 'Connection shutdown'
|
||||
with open_skynet_rpc('bad-actor') as session:
|
||||
with trio.fail_after(5):
|
||||
await session.rpc('skynet_shutdown')
|
||||
|
||||
|
||||
async def test_skynet_dgpu_connection_simple(skynet_running):
|
||||
async with open_skynet_rpc(
|
||||
|
||||
async def rpc_handler(req, ctx):
|
||||
...
|
||||
|
||||
fake_dgpu_addr = 'tcp://127.0.0.1:41001'
|
||||
rpc_server = SessionServer(
|
||||
fake_dgpu_addr,
|
||||
rpc_handler,
|
||||
cert_name='whitelist/testing.cert',
|
||||
key_name='testing.key'
|
||||
)
|
||||
|
||||
with open_skynet_rpc(
|
||||
'dgpu-0',
|
||||
security=True,
|
||||
cert_name='whitelist/testing',
|
||||
key_name='testing'
|
||||
) as rpc_call:
|
||||
cert_name='whitelist/testing.cert',
|
||||
key_name='testing.key'
|
||||
) as session:
|
||||
# check 0 nodes are connected
|
||||
res = await rpc_call('dgpu_workers')
|
||||
assert 'ok' in res.result
|
||||
res = await session.rpc('dgpu_workers')
|
||||
assert 'ok' in res.result.keys()
|
||||
assert res.result['ok'] == 0
|
||||
|
||||
# check next worker is None
|
||||
res = await rpc_call('dgpu_next')
|
||||
assert 'ok' in res.result
|
||||
res = await session.rpc('dgpu_next')
|
||||
assert 'ok' in res.result.keys()
|
||||
assert res.result['ok'] == None
|
||||
|
||||
# connect 1 dgpu
|
||||
res = await rpc_call('dgpu_online')
|
||||
assert 'ok' in res.result
|
||||
async with rpc_server.open() as rpc_server:
|
||||
# connect 1 dgpu
|
||||
res = await session.rpc(
|
||||
'dgpu_online', {
|
||||
'dgpu_addr': fake_dgpu_addr,
|
||||
'cert': 'whitelist/testing.cert'
|
||||
})
|
||||
assert 'ok' in res.result.keys()
|
||||
|
||||
# check 1 node is connected
|
||||
res = await rpc_call('dgpu_workers')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 1
|
||||
# check 1 node is connected
|
||||
res = await session.rpc('dgpu_workers')
|
||||
assert 'ok' in res.result.keys()
|
||||
assert res.result['ok'] == 1
|
||||
|
||||
# check next worker is 0
|
||||
res = await rpc_call('dgpu_next')
|
||||
assert 'ok' in res.result
|
||||
assert res.result['ok'] == 0
|
||||
# check next worker is 0
|
||||
res = await session.rpc('dgpu_next')
|
||||
assert 'ok' in res.result.keys()
|
||||
assert res.result['ok'] == 0
|
||||
|
||||
# disconnect 1 dgpu
|
||||
res = await rpc_call('dgpu_offline')
|
||||
assert 'ok' in res.result
|
||||
# disconnect 1 dgpu
|
||||
res = await session.rpc('dgpu_offline')
|
||||
assert 'ok' in res.result.keys()
|
||||
|
||||
# check 0 nodes are connected
|
||||
res = await rpc_call('dgpu_workers')
|
||||
assert 'ok' in res.result
|
||||
res = await session.rpc('dgpu_workers')
|
||||
assert 'ok' in res.result.keys()
|
||||
assert res.result['ok'] == 0
|
||||
|
||||
# check next worker is None
|
||||
res = await rpc_call('dgpu_next')
|
||||
assert 'ok' in res.result
|
||||
res = await session.rpc('dgpu_next')
|
||||
assert 'ok' in res.result.keys()
|
||||
assert res.result['ok'] == None
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import trio
|
||||
|
||||
from functools import partial
|
||||
|
||||
from skynet.db import open_new_database
|
||||
from skynet.brain import run_skynet
|
||||
from skynet.config import load_skynet_ini
|
||||
from skynet.frontend.telegram import run_skynet_telegram
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
'''You will need a telegram bot token configured on skynet.ini for this
|
||||
'''
|
||||
with open_new_database() as db_params:
|
||||
db_container, db_pass, db_host = db_params
|
||||
config = load_skynet_ini()
|
||||
|
||||
async def main():
|
||||
await run_skynet_telegram(
|
||||
'telegram-test',
|
||||
config['skynet.telegram-test']['token'],
|
||||
db_host=db_host,
|
||||
db_pass=db_pass
|
||||
)
|
||||
|
||||
trio.run(main)
|
Loading…
Reference in New Issue