diff --git a/.gitignore b/.gitignore
index e264fa9..e98d124 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+skynet.ini
.python-version
hf_home
outputs
diff --git a/Dockerfile.runtime+cuda b/Dockerfile.runtime+cuda
index 32d3c4a..27a5a66 100644
--- a/Dockerfile.runtime+cuda
+++ b/Dockerfile.runtime+cuda
@@ -32,3 +32,4 @@ env HF_HOME /hf_home
copy scripts scripts
copy tests tests
+expose 40000-45000
diff --git a/LICENSE b/LICENSE
index 0fb76a9..fe6b903 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,11 +1,662 @@
-A menos que sea especificamente indicado en el cabezal del archivo, se reservan
-todos los derechos sobre este codigo por parte de:
+ GNU AFFERO GENERAL PUBLIC LICENSE
+ Version 3, 19 November 2007
-Guillermo Rodriguez, guillermor@fing.edu.uy
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
-ENGLISH LICENSE:
+ Preamble
-Unless specifically indicated in the file header, all rights to this code are
-reserved by:
+ The GNU Affero General Public License is a free, copyleft license for
+software and other kinds of works, specifically designed to ensure
+cooperation with the community in the case of network server software.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+our General Public Licenses are intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ Developers that use our General Public Licenses protect your rights
+with two steps: (1) assert copyright on the software, and (2) offer
+you this License which gives you legal permission to copy, distribute
+and/or modify the software.
+
+ A secondary benefit of defending all users' freedom is that
+improvements made in alternate versions of the program, if they
+receive widespread use, become available for other developers to
+incorporate. Many developers of free software are heartened and
+encouraged by the resulting cooperation. However, in the case of
+software used on network servers, this result may fail to come about.
+The GNU General Public License permits making a modified version and
+letting the public access it on a server without ever releasing its
+source code to the public.
+
+ The GNU Affero General Public License is designed specifically to
+ensure that, in such cases, the modified source code becomes available
+to the community. It requires the operator of a network server to
+provide the source code of the modified version running there to the
+users of that server. Therefore, public use of a modified version, on
+a publicly accessible server, gives the public access to the source
+code of the modified version.
+
+ An older license, called the Affero General Public License and
+published by Affero, was designed to accomplish similar goals. This is
+a different license, not a version of the Affero GPL, but Affero has
+released a new version of the Affero GPL which permits relicensing under
+this license.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU Affero General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Remote Network Interaction; Use with the GNU General Public License.
+
+ Notwithstanding any other provision of this License, if you modify the
+Program, your modified version must prominently offer all users
+interacting with it remotely through a computer network (if your version
+supports such interaction) an opportunity to receive the Corresponding
+Source of your version by providing access to the Corresponding Source
+from a network server at no charge, through some standard or customary
+means of facilitating copying of software. This Corresponding Source
+shall include the Corresponding Source for any work covered by version 3
+of the GNU General Public License that is incorporated pursuant to the
+following paragraph.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the work with which it is combined will remain governed by version
+3 of the GNU General Public License.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU Affero General Public License from time to time. Such new versions
+will be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU Affero General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU Affero General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU Affero General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published
+ by the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source. For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code. There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for the
+specific requirements.
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU AGPL, see
+<https://www.gnu.org/licenses/>.
-Guillermo Rodriguez, guillermor@fing.edu.uy
diff --git a/requirements.test.txt b/requirements.test.txt
index ce51dec..f39926b 100644
--- a/requirements.test.txt
+++ b/requirements.test.txt
@@ -3,4 +3,4 @@ pytest
pytest-trio
psycopg2-binary
-git+https://github.com/guilledk/pytest-dockerctl.git@host_network#egg=pytest-dockerctl
+git+https://github.com/guilledk/pytest-dockerctl.git@multi_names#egg=pytest-dockerctl
diff --git a/requirements.txt b/requirements.txt
index 7afc143..c773225 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,3 +9,5 @@ protobuf
pyOpenSSL
trio_asyncio
pyTelegramBotAPI
+
+git+https://github.com/goodboy/tractor.git@master#egg=tractor
diff --git a/skynet.ini.example b/skynet.ini.example
new file mode 100644
index 0000000..7035920
--- /dev/null
+++ b/skynet.ini.example
@@ -0,0 +1,12 @@
+[skynet]
+certs_dir = certs
+
+[skynet.dgpu]
+hf_home = hf_home
+hf_token = hf_XxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXx
+
+[skynet.telegram]
+token = XXXXXXXXXX:xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
+
+[skynet.telegram-test]
+token = XXXXXXXXXX:xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
diff --git a/skynet/brain.py b/skynet/brain.py
index c442ba5..b121bd3 100644
--- a/skynet/brain.py
+++ b/skynet/brain.py
@@ -1,35 +1,24 @@
#!/usr/bin/python
-import time
-import json
-import uuid
-import zlib
import logging
-import traceback
-from uuid import UUID
-from pathlib import Path
-from functools import partial
from contextlib import asynccontextmanager as acm
from collections import OrderedDict
import trio
-import pynng
-import trio_asyncio
-from pynng import TLSConfig
-from OpenSSL.crypto import (
- load_privatekey,
- load_certificate,
- FILETYPE_PEM
-)
+from pynng import Context
-from .db import *
+from .utils import time_ms
+from .network import *
+from .protobuf import *
from .constants import *
-from .protobuf import *
+class SkynetRPCBadRequest(BaseException):
+    """Raised when an RPC request lacks parameters the brain requires.
+
+    Used, for instance, when a ``dgpu_online`` request does not carry a
+    ``dgpu_addr`` entry in its params.
+    """
+
class SkynetDGPUOffline(BaseException):
...
@@ -44,39 +33,71 @@ class SkynetShutdownRequested(BaseException):
@acm
-async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
+async def run_skynet(
+ rpc_address: str = DEFAULT_RPC_ADDR
+):
+ logging.basicConfig(level=logging.INFO)
+ logging.info('skynet is starting')
+
nodes = OrderedDict()
- wip_reqs = {}
- fin_reqs = {}
heartbeats = {}
next_worker: Optional[int] = None
- security = len(tls_whitelist) > 0
- def connect_node(uid):
+ def connect_node(req: SkynetRPCRequest):
nonlocal next_worker
- nodes[uid] = {
- 'task': None
- }
- logging.info(f'dgpu online: {uid}')
- if not next_worker:
- next_worker = 0
+ node_params = MessageToDict(req.params)
+ logging.info(f'got node params {node_params}')
+
+ if 'dgpu_addr' not in node_params:
+ raise SkynetRPCBadRequest(
+ f'DGPU connection params don\'t include dgpu addr')
+
+ session = SessionClient(
+ node_params['dgpu_addr'],
+ 'skynet',
+ cert_name='brain.cert',
+ key_name='brain.key',
+ ca_name=node_params['cert']
+ )
+ try:
+ session.connect()
+
+ node = {
+ 'task': None,
+ 'session': session
+ }
+ node.update(node_params)
+
+ nodes[req.uid] = node
+ logging.info(f'DGPU node online: {req.uid}')
+
+ if not next_worker:
+ next_worker = 0
+
+ except pynng.exceptions.ConnectionRefused:
+ logging.warning(f'error while dialing dgpu node... dropping...')
+ raise SkynetDGPUOffline('Connection to dgpu node addr failed.')
def disconnect_node(uid):
nonlocal next_worker
if uid not in nodes:
+ logging.warning(f'Attempt to disconnect unknown node {uid}')
return
+
i = list(nodes.keys()).index(uid)
+ nodes[uid]['session'].disconnect()
del nodes[uid]
if i < next_worker:
next_worker -= 1
+ logging.warning(f'DGPU node offline: {uid}')
+
if len(nodes) == 0:
- logging.info('nw: None')
+ logging.info('All nodes disconnected.')
next_worker = None
- logging.warning(f'dgpu offline: {uid}')
def is_worker_busy(nid: str):
return nodes[nid]['task'] != None
@@ -90,8 +111,6 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
def get_next_worker():
nonlocal next_worker
- logging.info('get next_worker called')
- logging.info(f'pre next_worker: {next_worker}')
if next_worker == None:
raise SkynetDGPUOffline('No workers connected, try again later')
@@ -113,392 +132,79 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
if next_worker >= len(nodes):
next_worker = 0
- logging.info(f'post next_worker: {next_worker}')
-
return nid
- async def dgpu_heartbeat_service():
- nonlocal heartbeats
- while True:
- await trio.sleep(60)
- rid = uuid.uuid4().hex
- beat_msg = DGPUBusMessage(
- rid=rid,
- nid='',
- method='heartbeat'
- )
- heartbeats.clear()
- heartbeats[rid] = int(time.time() * 1000)
- await dgpu_bus.asend(beat_msg.SerializeToString())
- logging.info('sent heartbeat')
-
- async def dgpu_bus_streamer():
- nonlocal wip_reqs, fin_reqs, heartbeats
- while True:
- raw_msg = await dgpu_bus.arecv()
- logging.info(f'streamer got {len(raw_msg)} bytes.')
- msg = DGPUBusMessage()
- msg.ParseFromString(raw_msg)
-
- if security:
- verify_protobuf_msg(msg, tls_whitelist[msg.auth.cert])
-
- rid = msg.rid
-
- if msg.method == 'heartbeat':
- sent_time = heartbeats[rid]
- delta = msg.params['time'] - sent_time
- logging.info(f'got heartbeat reply from {msg.nid}, ping: {delta}')
- continue
-
- if rid not in wip_reqs:
- continue
-
- if msg.method == 'binary-reply':
- logging.info('bin reply, recv extra data')
- raw_img = await dgpu_bus.arecv()
- msg = (msg, raw_img)
-
- fin_reqs[rid] = msg
- event = wip_reqs[rid]
- event.set()
- del wip_reqs[rid]
-
- async def dgpu_stream_one_img(req: DiffusionParameters, img_buf=None):
- nonlocal wip_reqs, fin_reqs, next_worker
- nid = get_next_worker()
- idx = list(nodes.keys()).index(nid)
- logging.info(f'dgpu_stream_one_img {idx}/{len(nodes)} {nid}')
- rid = uuid.uuid4().hex
- ack_event = trio.Event()
- img_event = trio.Event()
- wip_reqs[rid] = ack_event
-
- nodes[nid]['task'] = rid
-
- dgpu_req = DGPUBusMessage(
- rid=rid,
- nid=nid,
- method='diffuse')
- dgpu_req.params.update(req.to_dict())
-
- if security:
- dgpu_req.auth.cert = 'skynet'
- dgpu_req.auth.sig = sign_protobuf_msg(dgpu_req, tls_key)
-
- msg = dgpu_req.SerializeToString()
- if img_buf:
- logging.info(f'sending img of size {len(img_buf)} as attachment')
- logging.info(img_buf[:10])
- msg = f'BINEXT%$%$'.encode() + msg + b'%$%$' + img_buf
-
- await dgpu_bus.asend(msg)
-
- with trio.move_on_after(4):
- await ack_event.wait()
-
- logging.info(f'ack event: {ack_event.is_set()}')
-
- if not ack_event.is_set():
- disconnect_node(nid)
- raise SkynetDGPUOffline('dgpu failed to acknowledge request')
-
- ack_msg = fin_reqs[rid]
- if 'ack' not in ack_msg.params:
- disconnect_node(nid)
- raise SkynetDGPUOffline('dgpu failed to acknowledge request')
-
- wip_reqs[rid] = img_event
- with trio.move_on_after(30):
- await img_event.wait()
-
- logging.info(f'img event: {ack_event.is_set()}')
-
- if not img_event.is_set():
- disconnect_node(nid)
- raise SkynetDGPUComputeError('30 seconds timeout while processing request')
-
- nodes[nid]['task'] = None
-
- resp = fin_reqs[rid]
- del fin_reqs[rid]
- if isinstance(resp, tuple):
- meta, img = resp
- return rid, img, meta.params
-
- raise SkynetDGPUComputeError(MessageToDict(resp.params))
-
-
- async def handle_user_request(rpc_ctx, req):
- try:
- async with db_pool.acquire() as conn:
- user = await get_or_create_user(conn, req.uid)
-
- result = {}
-
- match req.method:
- case 'txt2img':
- logging.info('txt2img')
- user_config = {**(await get_user_config(conn, user))}
- del user_config['id']
- user_config.update(MessageToDict(req.params))
-
- req = DiffusionParameters(**user_config, image=False)
- rid, img, meta = await dgpu_stream_one_img(req)
- logging.info(f'done streaming {rid}')
- result = {
- 'id': rid,
- 'img': img.hex(),
- 'meta': meta
- }
-
- await update_user_stats(conn, user, last_prompt=user_config['prompt'])
- logging.info('updated user stats.')
-
- case 'img2img':
- logging.info('img2img')
- user_config = {**(await get_user_config(conn, user))}
- del user_config['id']
-
- params = MessageToDict(req.params)
- img_buf = bytes.fromhex(params['img'])
- del params['img']
- user_config.update(params)
-
- req = DiffusionParameters(**user_config, image=True)
-
- if not req.image:
- raise AssertionError('Didn\'t enable image flag for img2img?')
-
- rid, img, meta = await dgpu_stream_one_img(req, img_buf=img_buf)
- logging.info(f'done streaming {rid}')
- result = {
- 'id': rid,
- 'img': img.hex(),
- 'meta': meta
- }
-
- await update_user_stats(conn, user, last_prompt=user_config['prompt'])
- logging.info('updated user stats.')
-
- case 'redo':
- logging.info('redo')
- user_config = {**(await get_user_config(conn, user))}
- del user_config['id']
- prompt = await get_last_prompt_of(conn, user)
-
- if prompt:
- req = DiffusionParameters(
- prompt=prompt,
- **user_config,
- image=False
- )
- rid, img, meta = await dgpu_stream_one_img(req)
- result = {
- 'id': rid,
- 'img': img.hex(),
- 'meta': meta
- }
- await update_user_stats(conn, user)
- logging.info('updated user stats.')
-
- else:
- result = {
- 'error': 'skynet_no_last_prompt',
- 'message': 'No prompt to redo, do txt2img first'
- }
-
- case 'config':
- logging.info('config')
- if req.params['attr'] in CONFIG_ATTRS:
- logging.info(f'update: {req.params}')
- await update_user_config(
- conn, user, req.params['attr'], req.params['val'])
- logging.info('done')
-
- else:
- logging.warning(f'{req.params["attr"]} not in {CONFIG_ATTRS}')
-
- case 'stats':
- logging.info('stats')
- generated, joined, role = await get_user_stats(conn, user)
-
- result = {
- 'generated': generated,
- 'joined': joined.strftime(DATE_FORMAT),
- 'role': role
- }
-
- case _:
- logging.warn('unknown method')
-
- except SkynetDGPUOffline as e:
- result = {
- 'error': 'skynet_dgpu_offline',
- 'message': str(e)
- }
-
- except SkynetDGPUOverloaded as e:
- result = {
- 'error': 'skynet_dgpu_overloaded',
- 'message': str(e),
- 'nodes': len(nodes)
- }
-
- except SkynetDGPUComputeError as e:
- result = {
- 'error': 'skynet_dgpu_compute_error',
- 'message': str(e)
- }
- except BaseException as e:
- traceback.print_exception(type(e), e, e.__traceback__)
- result = {
- 'error': 'skynet_internal_error',
- 'message': str(e)
- }
-
+ async def rpc_handler(req: SkynetRPCRequest, ctx: Context):
+ result = {'ok': {}}
resp = SkynetRPCResponse()
- resp.result.update(result)
-
- if security:
- resp.auth.cert = 'skynet'
- resp.auth.sig = sign_protobuf_msg(resp, tls_key)
-
- logging.info('sending response')
- await rpc_ctx.asend(resp.SerializeToString())
- rpc_ctx.close()
- logging.info('done')
-
- async def request_service(n):
- nonlocal next_worker
- while True:
- ctx = sock.new_context()
- req = SkynetRPCRequest()
- req.ParseFromString(await ctx.arecv())
-
- if security:
- if req.auth.cert not in tls_whitelist:
- logging.warning(
- f'{req.cert} not in tls whitelist and security=True')
- continue
-
- try:
- verify_protobuf_msg(req, tls_whitelist[req.auth.cert])
-
- except ValueError:
- logging.warning(
- f'{req.cert} sent an unauthenticated msg with security=True')
- continue
-
- result = {}
+ try:
match req.method:
- case 'skynet_shutdown':
- raise SkynetShutdownRequested
-
case 'dgpu_online':
- connect_node(req.uid)
+ connect_node(req)
+
+ case 'dgpu_call':
+ nid = get_next_worker()
+ idx = list(nodes.keys()).index(nid)
+ node = nodes[nid]
+ logging.info(f'dgpu_call {idx}/{len(nodes)} {nid} @ {node["dgpu_addr"]}')
+ dgpu_time = await node['session'].rpc('dgpu_time')
+ if 'ok' not in dgpu_time.result:
+ status = MessageToDict(dgpu_time.result)
+ logging.warning(json.dumps(status, indent=4))
+ disconnect_node(nid)
+ raise SkynetDGPUComputeError(status['error'])
+
+ dgpu_time = dgpu_time.result['ok']
+ logging.info(f'ping to {nid}: {time_ms() - dgpu_time} ms')
+
+ try:
+ dgpu_result = await node['session'].rpc(
+ timeout=45, # give this 45 sec to run cause its compute
+ binext=req.bin,
+ **req.params
+ )
+ result = MessageToDict(dgpu_result.result)
+
+ if dgpu_result.bin:
+ resp.bin = dgpu_result.bin
+
+ except trio.TooSlowError:
+ result = {'error': 'timeout while processing request'}
case 'dgpu_offline':
disconnect_node(req.uid)
case 'dgpu_workers':
- result = len(nodes)
+ result = {'ok': len(nodes)}
case 'dgpu_next':
- result = next_worker
+ result = {'ok': next_worker}
- case 'heartbeat':
- logging.info('beat')
- result = {'time': time.time()}
+ case 'skynet_shutdown':
+ raise SkynetShutdownRequested
case _:
- n.start_soon(
- handle_user_request, ctx, req)
- continue
+ logging.warning(f'Unknown method {req.method}')
+ result = {'error': 'unknown method'}
- resp = SkynetRPCResponse()
- resp.result.update({'ok': result})
+ except BaseException as e:
+ result = {'error': str(e)}
- if security:
- resp.auth.cert = 'skynet'
- resp.auth.sig = sign_protobuf_msg(resp, tls_key)
+ resp.result.update(result)
- await ctx.asend(resp.SerializeToString())
+ return resp
- ctx.close()
+ rpc_server = SessionServer(
+ rpc_address,
+ rpc_handler,
+ cert_name='brain.cert',
+ key_name='brain.key'
+ )
-
- async with trio.open_nursery() as n:
- n.start_soon(dgpu_bus_streamer)
- n.start_soon(dgpu_heartbeat_service)
- n.start_soon(request_service, n)
- logging.info('starting rpc service')
+ async with rpc_server.open():
+ logging.info('rpc server is up')
yield
- logging.info('stopping rpc service')
- n.cancel_scope.cancel()
+ logging.info('skynet is shuting down...')
-
-@acm
-async def run_skynet(
- db_user: str = DB_USER,
- db_pass: str = DB_PASS,
- db_host: str = DB_HOST,
- rpc_address: str = DEFAULT_RPC_ADDR,
- dgpu_address: str = DEFAULT_DGPU_ADDR,
- security: bool = True
-):
- logging.basicConfig(level=logging.INFO)
- logging.info('skynet is starting')
-
- tls_config = None
- if security:
- # load tls certs
- certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
-
- tls_key_data = (certs_dir / DEFAULT_CERT_SKYNET_PRIV).read_text()
- tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
-
- tls_cert_data = (certs_dir / DEFAULT_CERT_SKYNET_PUB).read_text()
- tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data)
-
- tls_whitelist = {}
- for cert_path in (certs_dir / 'whitelist').glob('*.cert'):
- tls_whitelist[cert_path.stem] = load_certificate(
- FILETYPE_PEM, cert_path.read_text())
-
- cert_start = tls_cert_data.index('\n') + 1
- logging.info(f'tls_cert: {tls_cert_data[cert_start:cert_start+64]}...')
- logging.info(f'tls_whitelist len: {len(tls_whitelist)}')
-
- rpc_address = 'tls+' + rpc_address
- dgpu_address = 'tls+' + dgpu_address
- tls_config = TLSConfig(
- TLSConfig.MODE_SERVER,
- own_key_string=tls_key_data,
- own_cert_string=tls_cert_data)
-
- with (
- pynng.Rep0(recv_max_size=0) as rpc_sock,
- pynng.Bus0(recv_max_size=0) as dgpu_bus
- ):
- async with open_database_connection(
- db_user, db_pass, db_host) as db_pool:
-
- logging.info('connected to db.')
- if security:
- rpc_sock.tls_config = tls_config
- dgpu_bus.tls_config = tls_config
-
- rpc_sock.listen(rpc_address)
- dgpu_bus.listen(dgpu_address)
-
- try:
- async with open_rpc_service(
- rpc_sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
- yield
-
- except SkynetShutdownRequested:
- ...
-
- logging.info('disconnected from db.')
+ logging.info('skynet down.')
diff --git a/skynet/cli.py b/skynet/cli.py
index 2573106..021e1e6 100644
--- a/skynet/cli.py
+++ b/skynet/cli.py
@@ -17,8 +17,8 @@ if torch_enabled:
from .dgpu import open_dgpu_node
from .brain import run_skynet
+from .config import *
from .constants import ALGOS, DEFAULT_RPC_ADDR, DEFAULT_DGPU_ADDR
-
from .frontend.telegram import run_skynet_telegram
@@ -38,8 +38,8 @@ def skynet(*args, **kwargs):
@click.option('--steps', '-s', default=26)
@click.option('--seed', '-S', default=None)
def txt2img(*args, **kwargs):
- assert 'HF_TOKEN' in os.environ
- utils.txt2img(os.environ['HF_TOKEN'], **kwargs)
+ _, hf_token, _, cfg = init_env_from_config()
+ utils.txt2img(hf_token, **kwargs)
@click.command()
@click.option('--model', '-m', default='midj')
@@ -52,9 +52,9 @@ def txt2img(*args, **kwargs):
@click.option('--steps', '-s', default=26)
@click.option('--seed', '-S', default=None)
def img2img(model, prompt, input, output, strength, guidance, steps, seed):
- assert 'HF_TOKEN' in os.environ
+ _, hf_token, _, cfg = init_env_from_config()
utils.img2img(
- os.environ['HF_TOKEN'],
+ hf_token,
model=model,
prompt=prompt,
img_path=input,
@@ -76,6 +76,12 @@ def upscale(input, output, model):
model_path=model)
+@skynet.command()
+def download():
+    """Fetch all models via ``utils.download_all_models``.
+
+    The Hugging Face token comes from ``init_env_from_config`` (second
+    element of the tuple it returns).
+    """
+    hf_token = init_env_from_config()[1]
+    utils.download_all_models(hf_token)
+
+
@skynet.group()
def run(*args, **kwargs):
pass
@@ -85,29 +91,17 @@ def run(*args, **kwargs):
@click.option('--loglevel', '-l', default='warning', help='Logging level')
@click.option(
'--host', '-H', default=DEFAULT_RPC_ADDR)
-@click.option(
- '--host-dgpu', '-D', default=DEFAULT_DGPU_ADDR)
-@click.option(
- '--db-host', '-h', default='localhost:5432')
-@click.option(
- '--db-pass', '-p', default='password')
def brain(
loglevel: str,
- host: str,
- host_dgpu: str,
- db_host: str,
- db_pass: str
+ host: str
):
async def _run_skynet():
async with run_skynet(
- db_host=db_host,
- db_pass=db_pass,
- rpc_address=host,
- dgpu_address=host_dgpu
+ rpc_address=host
):
await trio.sleep_forever()
- trio_asyncio.run(_run_skynet)
+ trio.run(_run_skynet)
@run.command()
@@ -115,9 +109,9 @@ def brain(
@click.option(
'--uid', '-u', required=True)
@click.option(
- '--key', '-k', default='dgpu')
+ '--key', '-k', default='dgpu.key')
@click.option(
- '--cert', '-c', default='whitelist/dgpu')
+ '--cert', '-c', default='whitelist/dgpu.cert')
@click.option(
'--algos', '-a', default=json.dumps(['midj']))
@click.option(
@@ -159,11 +153,11 @@ def telegram(
cert: str,
rpc: str
):
- assert 'TG_TOKEN' in os.environ
+ _, _, tg_token, cfg = init_env_from_config()
trio_asyncio.run(
partial(
run_skynet_telegram,
- os.environ['TG_TOKEN'],
+ tg_token,
key_name=key,
cert_name=cert,
rpc_address=rpc
diff --git a/skynet/config.py b/skynet/config.py
new file mode 100644
index 0000000..91d6101
--- /dev/null
+++ b/skynet/config.py
@@ -0,0 +1,39 @@
+#!/usr/bin/python
+
+import os
+
+from pathlib import Path
+from configparser import ConfigParser
+
+from .constants import DEFAULT_CONFIG_PATH
+
+
+def load_skynet_ini(
+    file_path=DEFAULT_CONFIG_PATH
+):
+    """Parse the ini file at ``file_path`` and return the ``ConfigParser``.
+
+    ``ConfigParser.read`` skips files it cannot open, so a missing file
+    results in an empty parser instead of an exception.
+    """
+    parser = ConfigParser()
+    parser.read(file_path)
+    return parser
+
+
+def init_env_from_config(
+    file_path=DEFAULT_CONFIG_PATH
+):
+    """Resolve runtime settings from the environment or the ini file.
+
+    ``HF_TOKEN``, ``HF_HOME`` and ``TG_TOKEN`` are taken from the process
+    environment when set; otherwise they are read from the ini file at
+    ``file_path`` (``[skynet.dgpu]`` ``hf_token`` / ``hf_home`` and
+    ``[skynet.telegram]`` ``token`` respectively).
+
+    Returns a ``(hf_home, hf_token, tg_token, config)`` tuple, ``config``
+    being the parsed ``ConfigParser``.
+
+    Raises ``KeyError`` when a value is neither in the environment nor in
+    the ini file.
+    """
+    # honour the caller supplied path; previously the default file was
+    # always loaded and ``file_path`` was silently ignored
+    config = load_skynet_ini(file_path=file_path)
+
+    def _env_or_ini(env_name, section, option):
+        # the ini lookup stays lazy so a missing section only matters when
+        # the environment does not provide the value either
+        if env_name in os.environ:
+            return os.environ[env_name]
+        return config[section][option]
+
+    hf_token = _env_or_ini('HF_TOKEN', 'skynet.dgpu', 'hf_token')
+    hf_home = _env_or_ini('HF_HOME', 'skynet.dgpu', 'hf_home')
+    tg_token = _env_or_ini('TG_TOKEN', 'skynet.telegram', 'token')
+
+    return hf_home, hf_token, tg_token, config
diff --git a/skynet/constants.py b/skynet/constants.py
index 1478269..3d96a2c 100644
--- a/skynet/constants.py
+++ b/skynet/constants.py
@@ -1,14 +1,9 @@
#!/usr/bin/python
-VERSION = '0.1a8'
+VERSION = '0.1a9'
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
-DB_HOST = 'localhost:5432'
-DB_USER = 'skynet'
-DB_PASS = 'password'
-DB_NAME = 'skynet'
-
ALGOS = {
'midj': 'prompthero/openjourney',
'stable': 'runwayml/stable-diffusion-v1-5',
@@ -118,6 +113,7 @@ DEFAULT_ALGO = 'midj'
DEFAULT_ROLE = 'pleb'
DEFAULT_UPSCALER = None
+DEFAULT_CONFIG_PATH = 'skynet.ini'
DEFAULT_CERTS_DIR = 'certs'
DEFAULT_CERT_WHITELIST_DIR = 'whitelist'
DEFAULT_CERT_SKYNET_PUB = 'brain.cert'
diff --git a/skynet/db/__init__.py b/skynet/db/__init__.py
new file mode 100644
index 0000000..fd45c9e
--- /dev/null
+++ b/skynet/db/__init__.py
@@ -0,0 +1,5 @@
+#!/usr/bin/python
+
+from .proxy import open_database_connection
+
+from .functions import open_new_database
diff --git a/skynet/db.py b/skynet/db/functions.py
similarity index 73%
rename from skynet/db.py
rename to skynet/db/functions.py
index fbcf202..10863c2 100644
--- a/skynet/db.py
+++ b/skynet/db/functions.py
@@ -1,18 +1,21 @@
#!/usr/bin/python
+import time
+import random
+import string
import logging
from typing import Optional
from datetime import datetime
-from contextlib import asynccontextmanager as acm
+from contextlib import contextmanager as cm
-import trio
-import triopg
-import trio_asyncio
+import docker
+import psycopg2
from asyncpg.exceptions import UndefinedColumnError
+from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
-from .constants import *
+from ..constants import *
DB_INIT_SQL = '''
@@ -75,29 +78,67 @@ def try_decode_uid(uid: str):
return None, None
-@acm
-async def open_database_connection(
- db_user: str = DB_USER,
- db_pass: str = DB_PASS,
- db_host: str = DB_HOST,
- db_name: str = DB_NAME
-):
- async with trio_asyncio.open_loop() as loop:
- async with triopg.create_pool(
- dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}'
- ) as pool_conn:
- async with pool_conn.acquire() as conn:
- res = await conn.execute(f'''
- select distinct table_schema
- from information_schema.tables
- where table_schema = \'{db_name}\'
- ''')
- if '1' in res:
- logging.info('schema already in db, skipping init')
- else:
- await conn.execute(DB_INIT_SQL)
+@cm
+def open_new_database():
+    """Start a throwaway postgres container holding an empty ``skynet`` db.
+
+    A ``postgres`` container named ``skynet-test-postgres`` is started with
+    a random superuser password, then a ``skynet`` user and database are
+    created inside it.
+
+    Yields a ``(container, password, host)`` tuple: ``password`` is the one
+    set for the ``skynet`` user and ``host`` is ``localhost:<port>`` using
+    the host port docker mapped to the container's ``5432/tcp``.
+
+    The container is stopped when the context exits, including when setup
+    or the caller's ``with`` body raises.
+    """
+    rpassword = ''.join(
+        random.choice(string.ascii_lowercase)
+        for _ in range(12))
+    password = ''.join(
+        random.choice(string.ascii_lowercase)
+        for _ in range(12))
- yield pool_conn
+
+    dclient = docker.from_env()
+
+    container = dclient.containers.run(
+        'postgres',
+        name='skynet-test-postgres',
+        ports={'5432/tcp': None},
+        environment={
+            'POSTGRES_PASSWORD': rpassword
+        },
+        detach=True,
+        remove=True
+    )
+
+    # the container is running from here on: guarantee it gets stopped no
+    # matter how the setup below or the caller's ``with`` body ends
+    try:
+        for log in container.logs(stream=True):
+            log = log.decode().rstrip()
+            logging.info(log)
+            if ('database system is ready to accept connections' in log or
+                    'database system is shut down' in log):
+                break
+
+        # refresh attrs so the dynamically assigned host port is visible
+        container.reload()
+        port = container.ports['5432/tcp'][0]['HostPort']
+        host = f'localhost:{port}'
+
+        # postgres reports it is ready to accept connections slightly
+        # before it really is, so give it a moment
+        time.sleep(1)
+        logging.info('creating skynet db...')
+
+        conn = psycopg2.connect(
+            user='postgres',
+            password=rpassword,
+            host='localhost',
+            port=port
+        )
+        try:
+            logging.info('connected...')
+            # CREATE DATABASE cannot run inside a transaction block
+            conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
+            with conn.cursor() as cursor:
+                # ``password`` only contains ascii lowercase letters
+                # generated above, so interpolating it is safe here
+                cursor.execute(
+                    f'CREATE USER skynet WITH PASSWORD \'{password}\'')
+                cursor.execute(
+                    'CREATE DATABASE skynet')
+                cursor.execute(
+                    'GRANT ALL PRIVILEGES ON DATABASE skynet TO skynet')
+
+        finally:
+            conn.close()
+
+        logging.info('done.')
+        yield container, password, host
+
+    finally:
+        container.stop()
async def get_user(conn, uid: str):
diff --git a/skynet/db/proxy.py b/skynet/db/proxy.py
new file mode 100644
index 0000000..d2f86c1
--- /dev/null
+++ b/skynet/db/proxy.py
@@ -0,0 +1,123 @@
+#!/usr/bin/python
+
+import importlib
+
+from contextlib import asynccontextmanager as acm
+
+import trio
+import tractor
+import asyncpg
+import asyncio
+import trio_asyncio
+
+
+# NOTE(review): not referenced anywhere in the visible part of this module;
+# open_database_connection() passes ``infect_asyncio=True`` to
+# ``start_actor()`` directly. Presumably intended as tractor spawn kwargs --
+# confirm or remove.
+_spawn_kwargs = {
+    'infect_asyncio': True,
+}
+
+
+async def aio_db_proxy(
+    to_trio: trio.MemorySendChannel,
+    from_trio: asyncio.Queue,
+    db_user: str = 'skynet',
+    db_pass: str = 'password',
+    db_host: str = 'localhost:5432',
+    db_name: str = 'skynet'
+) -> None:
+    """``asyncio`` side of the db proxy: run db functions on request.
+
+    Opens an ``asyncpg`` pool, runs ``DB_INIT_SQL`` from
+    ``skynet.db.functions`` unless the schema query reports a row, sends
+    ``'start'`` through ``to_trio`` and then serves requests forever.
+
+    Every message taken from ``from_trio`` is a dict with a ``method`` key
+    naming a function in ``skynet.db.functions`` plus optional ``args`` and
+    ``kwargs``; the function is awaited as ``method(conn, *args, **kwargs)``
+    with a pooled connection and its result is sent back via ``to_trio``.
+    """
+    # local import: ``logging`` is not imported at module level here
+    import logging
+
+    db = importlib.import_module('skynet.db.functions')
+
+    # the context manager form closes the pool when this task ends
+    async with asyncpg.create_pool(
+        dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}'
+    ) as pool_conn:
+
+        async with pool_conn.acquire() as conn:
+            res = await conn.execute(f'''
+                select distinct table_schema
+                from information_schema.tables
+                where table_schema = \'{db_name}\'
+            ''')
+            if '1' in res:
+                logging.info('schema already in db, skipping init')
+            else:
+                # the init script lives in the functions module
+                await conn.execute(db.DB_INIT_SQL)
+
+        # a first message must be sent **from** this ``asyncio``
+        # task or the ``trio`` side will never unblock from
+        # ``tractor.to_asyncio.open_channel_from():``
+        to_trio.send_nowait('start')
+
+        # XXX: this uses an ``from_trio: asyncio.Queue`` currently but we
+        # should probably offer something better.
+        while True:
+            msg = await from_trio.get()
+
+            # only the method is looked up on the module; ``args`` and
+            # ``kwargs`` are plain values carried by the message
+            method = getattr(db, msg['method'])
+            args = msg.get('args', [])
+            kwargs = msg.get('kwargs', {})
+
+            async with pool_conn.acquire() as conn:
+                result = await method(conn, *args, **kwargs)
+
+            to_trio.send_nowait(result)
+
+
+@tractor.context
+async def trio_to_aio_db_proxy(
+    ctx: tractor.Context,
+    db_user: str = 'skynet',
+    db_pass: str = 'password',
+    db_host: str = 'localhost:5432',
+    db_name: str = 'skynet'
+):
+    """``trio`` side of the db proxy: relay stream messages to ``asyncio``.
+
+    Starts ``aio_db_proxy`` through ``tractor.to_asyncio.open_channel_from``
+    with the given connection parameters, reports its first message to the
+    parent with ``ctx.started`` and then, for every message arriving on the
+    context stream, forwards it to the ``asyncio`` task and streams the
+    answer back.
+    """
+    db_params = {
+        'db_user': db_user,
+        'db_pass': db_pass,
+        'db_host': db_host,
+        'db_name': db_name,
+    }
+
+    # entering the channel blocks until the ``asyncio`` task has sent its
+    # "first" message
+    async with tractor.to_asyncio.open_channel_from(
+        aio_db_proxy,
+        **db_params
+    ) as (first, chan):
+
+        assert first == 'start'
+        await ctx.started(first)
+
+        async with ctx.open_stream() as stream:
+
+            async for request in stream:
+                await chan.send(request)
+
+                # hand the db result back to the parent actor-task
+                reply = await chan.receive()
+                await stream.send(reply)
+
+
+@acm
+async def open_database_connection(
+    db_user: str = 'skynet',
+    db_pass: str = 'password',
+    db_host: str = 'localhost:5432',
+    db_name: str = 'skynet'
+):
+    """Spawn the db proxy actor and yield a call helper for it.
+
+    An ``aio_db_proxy`` actor is started with ``infect_asyncio=True`` and a
+    context running ``trio_to_aio_db_proxy`` is opened on it with the given
+    connection parameters.
+
+    Yields an async callable ``_db_pc(method, *args, **kwargs)`` that sends
+    ``{'method', 'args', 'kwargs'}`` over the context stream and returns the
+    next message received on it.
+    """
+    async with tractor.open_nursery() as n:
+        p = await n.start_actor(
+            'aio_db_proxy',
+            enable_modules=[__name__],
+            infect_asyncio=True,
+        )
+        async with p.open_context(
+            trio_to_aio_db_proxy,
+            db_user=db_user,
+            db_pass=db_pass,
+            db_host=db_host,
+            db_name=db_name
+        ) as (ctx, _first):
+            async with ctx.open_stream() as stream:
+
+                async def _db_pc(method: str, *args, **kwargs):
+                    await stream.send({
+                        'method': method,
+                        'args': args,
+                        'kwargs': kwargs
+                    })
+                    return await stream.receive()
+
+                yield _db_pc
+
+        # an actor spawned with ``start_actor`` keeps running until it is
+        # explicitly cancelled; without this the nursery block would wait
+        # on it and never exit once the caller is done
+        await p.cancel_actor()
diff --git a/skynet/dgpu.py b/skynet/dgpu.py
index 752c8b8..79c6c49 100644
--- a/skynet/dgpu.py
+++ b/skynet/dgpu.py
@@ -2,29 +2,17 @@
import gc
import io
-import trio
import json
-import uuid
-import time
-import zlib
import random
import logging
-import traceback
from PIL import Image
from typing import List, Optional
-from pathlib import Path
-from contextlib import ExitStack
-import pynng
+import trio
import torch
-from pynng import TLSConfig
-from OpenSSL.crypto import (
- load_privatekey,
- load_certificate,
- FILETYPE_PEM
-)
+from pynng import Context
from diffusers import (
StableDiffusionPipeline,
StableDiffusionImg2ImgPipeline,
@@ -34,12 +22,9 @@ from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from diffusers.models import UNet2DConditionModel
-from .utils import (
- pipeline_for,
- convert_from_cv2_to_image, convert_from_image_to_cv2
-)
+from .utils import *
+from .network import *
from .protobuf import *
-from .frontend import open_skynet_rpc
from .constants import *
@@ -64,65 +49,16 @@ class DGPUComputeError(BaseException):
...
-class ReconnectingBus:
-
- def __init__(self, address: str, tls_config: Optional[TLSConfig]):
- self.address = address
- self.tls_config = tls_config
-
- self._stack = ExitStack()
- self._sock = None
- self._closed = True
-
- def connect(self):
- self._sock = self._stack.enter_context(
- pynng.Bus0(recv_max_size=0))
- self._sock.tls_config = self.tls_config
- self._sock.dial(self.address)
- self._closed = False
-
- async def arecv(self):
- while True:
- try:
- return await self._sock.arecv()
-
- except pynng.exceptions.Closed:
- if self._closed:
- raise
-
- async def asend(self, msg):
- while True:
- try:
- return await self._sock.asend(msg)
-
- except pynng.exceptions.Closed:
- if self._closed:
- raise
-
- def close(self):
- self._stack.close()
- self._stack = ExitStack()
- self._closed = True
-
- def reconnect(self):
- self.close()
- self.connect()
-
-
async def open_dgpu_node(
cert_name: str,
unique_id: str,
key_name: Optional[str],
rpc_address: str = DEFAULT_RPC_ADDR,
dgpu_address: str = DEFAULT_DGPU_ADDR,
- initial_algos: Optional[List[str]] = None,
- security: bool = True
+ initial_algos: Optional[List[str]] = None
):
- logging.basicConfig(level=logging.INFO)
+ logging.basicConfig(level=logging.DEBUG)
logging.info(f'starting dgpu node!')
-
- name = uuid.uuid4()
-
logging.info(f'loading models...')
upscaler = init_upscaler()
@@ -141,241 +77,140 @@ async def open_dgpu_node(
logging.info('memory summary:')
logging.info('\n' + torch.cuda.memory_summary())
- async def gpu_compute_one(ireq: DiffusionParameters, image=None):
- algo = ireq.algo + 'img' if image else ireq.algo
- if algo not in models:
- least_used = list(models.keys())[0]
- for model in models:
- if models[least_used]['generated'] > models[model]['generated']:
- least_used = model
+ async def gpu_compute_one(method: str, params: dict, binext: Optional[bytes] = None):
+ match method:
+ case 'diffuse':
+ image = None
+ algo = params['algo']
+ if binext:
+ algo += 'img'
+ image = Image.open(io.BytesIO(binext))
+ w, h = image.size
+ logging.info(f'user sent img of size {image.size}')
- del models[least_used]
- gc.collect()
+ if w > 512 or h > 512:
+ image.thumbnail((512, 512))
+ logging.info(f'resized it to {image.size}')
- models[algo] = {
- 'pipe': pipeline_for(ireq.algo, image=True if image else False),
- 'generated': 0
- }
+ if algo not in models:
+ logging.info(f'{algo} not in loaded models, swapping...')
+ least_used = list(models.keys())[0]
+ for model in models:
+ if models[least_used]['generated'] > models[model]['generated']:
+ least_used = model
- _params = {}
- if ireq.image:
- _params['image'] = image
- _params['strength'] = ireq.strength
+ del models[least_used]
+ gc.collect()
- else:
- _params['width'] = int(ireq.width)
- _params['height'] = int(ireq.height)
+ models[algo] = {
+ 'pipe': pipeline_for(params['algo'], image=True if binext else False),
+ 'generated': 0
+ }
+ logging.info(f'swapping done.')
- try:
- image = models[algo]['pipe'](
- ireq.prompt,
- **_params,
- guidance_scale=ireq.guidance,
- num_inference_steps=int(ireq.step),
- generator=torch.Generator("cuda").manual_seed(ireq.seed)
- ).images[0]
+ _params = {}
+ logging.info(method)
+ logging.info(json.dumps(params, indent=4))
+ logging.info(f'binext: {len(binext) if binext else 0} bytes')
+ if binext:
+ _params['image'] = image
+ _params['strength'] = params['strength']
- if ireq.upscaler == 'x4':
- logging.info(f'size: {len(image.tobytes())}')
- logging.info('performing upscale...')
- input_img = image.convert('RGB')
- up_img, _ = upscaler.enhance(
- convert_from_image_to_cv2(input_img), outscale=4)
+ else:
+ _params['width'] = int(params['width'])
+ _params['height'] = int(params['height'])
- image = convert_from_cv2_to_image(up_img)
- logging.info('done')
+ try:
+ image = models[algo]['pipe'](
+ params['prompt'],
+ **_params,
+ guidance_scale=params['guidance'],
+ num_inference_steps=int(params['step']),
+ generator=torch.Generator("cuda").manual_seed(
+ int(params['seed']) if params['seed'] else random.randint(0, 2 ** 64)
+ )
+ ).images[0]
- img_byte_arr = io.BytesIO()
- image.save(img_byte_arr, format='PNG')
- raw_img = img_byte_arr.getvalue()
- logging.info(f'final img size {len(raw_img)} bytes.')
+ if params['upscaler'] == 'x4':
+ logging.info(f'size: {len(image.tobytes())}')
+ logging.info('performing upscale...')
+ input_img = image.convert('RGB')
+ up_img, _ = upscaler.enhance(
+ convert_from_image_to_cv2(input_img), outscale=4)
- return raw_img
+ image = convert_from_cv2_to_image(up_img)
+ logging.info('done')
- except BaseException as e:
- logging.error(e)
- raise DGPUComputeError(str(e))
+ img_byte_arr = io.BytesIO()
+ image.save(img_byte_arr, format='PNG')
+ raw_img = img_byte_arr.getvalue()
+ logging.info(f'final img size {len(raw_img)} bytes.')
- finally:
- torch.cuda.empty_cache()
+ return raw_img
+
+ except BaseException as e:
+ logging.error(e)
+ raise DGPUComputeError(str(e))
+
+ finally:
+ torch.cuda.empty_cache()
+
+ case _:
+ raise DGPUComputeError('Unsupported compute method')
+
+ async def rpc_handler(req: SkynetRPCRequest, ctx: Context):
+ """Answer one RPC request received by the dgpu session server.
+
+ 'dgpu_time' is answered with {'ok': time_ms()}. Any other method is
+ forwarded to gpu_compute_one together with the request params (converted
+ to a dict) and the optional binary attachment; the bytes it returns are
+ placed in resp.bin. A DGPUComputeError is reported to the caller as
+ {'error': <message>} in resp.result instead of being raised.
+
+ ctx is part of the handler signature but is not used here.
+ """
+ result = {}
+ resp = SkynetRPCResponse()
+
+ match req.method:
+ case 'dgpu_time':
+ result = {'ok': time_ms()}
+
+ case _:
+ logging.debug(f'dgpu got one request: {req.method}')
+ try:
+ # an empty req.bin is passed on as None (no attachment)
+ resp.bin = await gpu_compute_one(
+ req.method, MessageToDict(req.params),
+ binext=req.bin if req.bin else None
+ )
+ logging.debug(f'dgpu processed one request')
+
+ except DGPUComputeError as e:
+ result = {'error': str(e)}
+
+ # result stays empty when the compute call succeeded
+ resp.result.update(result)
+ return resp
+
+ rpc_server = SessionServer(
+ dgpu_address,
+ rpc_handler,
+ cert_name=cert_name,
+ key_name=key_name
+ )
+ skynet_rpc = SessionClient(
+ rpc_address,
+ unique_id,
+ cert_name=cert_name,
+ key_name=key_name
+ )
+ skynet_rpc.connect()
- async with (
- open_skynet_rpc(
- unique_id,
- rpc_address=rpc_address,
- security=security,
- cert_name=cert_name,
- key_name=key_name
- ) as rpc_call,
- trio.open_nursery() as n
- ):
+ async with rpc_server.open() as rpc_server:
+ res = await skynet_rpc.rpc(
+ 'dgpu_online', {
+ 'dgpu_addr': rpc_server.addr,
+ 'cert': cert_name
+ })
- tls_config = None
- if security:
- # load tls certs
- if not key_name:
- key_name = cert_name
-
- certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
-
- skynet_cert_path = certs_dir / 'brain.cert'
- tls_cert_path = certs_dir / f'{cert_name}.cert'
- tls_key_path = certs_dir / f'{key_name}.key'
-
- cert_name = tls_cert_path.stem
-
- skynet_cert_data = skynet_cert_path.read_text()
- skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data)
-
- tls_cert_data = tls_cert_path.read_text()
-
- tls_key_data = tls_key_path.read_text()
- tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
-
- logging.info(f'skynet cert: {skynet_cert_path}')
- logging.info(f'dgpu cert: {tls_cert_path}')
- logging.info(f'dgpu key: {tls_key_path}')
-
- dgpu_address = 'tls+' + dgpu_address
- tls_config = TLSConfig(
- TLSConfig.MODE_CLIENT,
- own_key_string=tls_key_data,
- own_cert_string=tls_cert_data,
- ca_string=skynet_cert_data)
-
- logging.info(f'connecting to {dgpu_address}')
-
- dgpu_bus = ReconnectingBus(dgpu_address, tls_config)
- dgpu_bus.connect()
-
- last_msg = time.time()
- async def connection_refresher(refresh_time: int = 120):
- nonlocal last_msg
- while True:
- now = time.time()
- last_msg_time_delta = now - last_msg
- logging.info(f'time since last msg: {last_msg_time_delta}')
- if last_msg_time_delta > refresh_time:
- dgpu_bus.reconnect()
- logging.info('reconnected!')
- last_msg = now
-
- await trio.sleep(refresh_time)
-
- n.start_soon(connection_refresher)
-
- res = await rpc_call('dgpu_online')
assert 'ok' in res.result
try:
- while True:
- msg = await dgpu_bus.arecv()
-
- img = None
- if b'BINEXT' in msg:
- header, msg, img_raw = msg.split(b'%$%$')
- logging.info(f'got img attachment of size {len(img_raw)}')
- logging.info(img_raw[:10])
- raw_img = zlib.decompress(img_raw)
- logging.info(raw_img[:10])
- img = Image.open(io.BytesIO(raw_img))
- w, h = img.size
- logging.info(f'user sent img of size {img.size}')
-
- if w > 512 or h > 512:
- img.thumbnail((512, 512))
- logging.info(f'resized it to {img.size}')
-
-
- req = DGPUBusMessage()
- req.ParseFromString(msg)
- last_msg = time.time()
-
- if req.method == 'heartbeat':
- rep = DGPUBusMessage(
- rid=req.rid,
- nid=unique_id,
- method=req.method
- )
- rep.params.update({'time': int(time.time() * 1000)})
-
- if security:
- rep.auth.cert = cert_name
- rep.auth.sig = sign_protobuf_msg(rep, tls_key)
-
- await dgpu_bus.asend(rep.SerializeToString())
- logging.info('heartbeat reply')
- continue
-
- if req.nid != unique_id:
- logging.info(
- f'witnessed msg {req.rid}, node involved: {req.nid}')
- continue
-
- if security:
- verify_protobuf_msg(req, skynet_cert)
-
-
- ack_resp = DGPUBusMessage(
- rid=req.rid,
- nid=req.nid
- )
- ack_resp.params.update({'ack': {}})
-
- if security:
- ack_resp.auth.cert = cert_name
- ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
-
- # send ack
- await dgpu_bus.asend(ack_resp.SerializeToString())
-
- logging.info(f'sent ack, processing {req.rid}...')
-
- try:
- img_req = DiffusionParameters(**req.params)
-
- if not img_req.seed:
- img_req.seed = random.randint(0, 2 ** 64)
-
- img = await gpu_compute_one(img_req, image=img)
- img_resp = DGPUBusMessage(
- rid=req.rid,
- nid=req.nid,
- method='binary-reply'
- )
- img_resp.params.update({
- 'len': len(img),
- 'meta': img_req.to_dict()
- })
-
- except DGPUComputeError as e:
- traceback.print_exception(type(e), e, e.__traceback__)
- img_resp = DGPUBusMessage(
- rid=req.rid,
- nid=req.nid
- )
- img_resp.params.update({'error': str(e)})
-
-
- if security:
- img_resp.auth.cert = cert_name
- img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key)
-
- # send final image
- logging.info('sending img back...')
- raw_msg = img_resp.SerializeToString()
- await dgpu_bus.asend(raw_msg)
- logging.info(f'sent {len(raw_msg)} bytes.')
- if img_resp.method == 'binary-reply':
- await dgpu_bus.asend(zlib.compress(img))
- logging.info(f'sent {len(img)} bytes.')
+ await trio.sleep_forever()
except KeyboardInterrupt:
logging.info('interrupt caught, stopping...')
- n.cancel_scope.cancel()
- dgpu_bus.close()
finally:
- res = await rpc_call('dgpu_offline')
+ res = await skynet_rpc.rpc('dgpu_offline')
assert 'ok' in res.result
diff --git a/skynet/frontend/__init__.py b/skynet/frontend/__init__.py
index f8193a2..04d6b90 100644
--- a/skynet/frontend/__init__.py
+++ b/skynet/frontend/__init__.py
@@ -4,7 +4,7 @@ import json
from typing import Union, Optional
from pathlib import Path
-from contextlib import asynccontextmanager as acm
+from contextlib import contextmanager as cm
import pynng
@@ -17,6 +17,7 @@ from OpenSSL.crypto import (
from google.protobuf.struct_pb2 import Struct
+from ..network import SessionClient
from ..constants import *
from ..protobuf.auth import *
@@ -39,75 +40,23 @@ class ConfigSizeDivisionByEight(BaseException):
...
-@acm
-async def open_skynet_rpc(
+@cm
+def open_skynet_rpc(
unique_id: str,
rpc_address: str = DEFAULT_RPC_ADDR,
- security: bool = False,
cert_name: Optional[str] = None,
key_name: Optional[str] = None
):
- tls_config = None
-
- if security:
- # load tls certs
- if not key_name:
- key_name = cert_name
-
- certs_dir = Path(DEFAULT_CERTS_DIR).resolve()
-
- skynet_cert_data = (certs_dir / 'brain.cert').read_text()
- skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data)
-
- tls_cert_path = certs_dir / f'{cert_name}.cert'
- tls_cert_data = tls_cert_path.read_text()
- tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data)
- cert_name = tls_cert_path.stem
-
- tls_key_data = (certs_dir / f'{key_name}.key').read_text()
- tls_key = load_privatekey(FILETYPE_PEM, tls_key_data)
-
- rpc_address = 'tls+' + rpc_address
- tls_config = TLSConfig(
- TLSConfig.MODE_CLIENT,
- own_key_string=tls_key_data,
- own_cert_string=tls_cert_data,
- ca_string=skynet_cert_data)
-
- with pynng.Req0(recv_max_size=0) as sock:
- if security:
- sock.tls_config = tls_config
-
- sock.dial(rpc_address)
-
- async def _rpc_call(
- method: str,
- params: dict = {},
- uid: Optional[str] = None
- ):
- req = SkynetRPCRequest()
- req.uid = uid if uid else unique_id
- req.method = method
- req.params.update(params)
-
- if security:
- req.auth.cert = cert_name
- req.auth.sig = sign_protobuf_msg(req, tls_key)
-
- ctx = sock.new_context()
- await ctx.asend(req.SerializeToString())
-
- resp = SkynetRPCResponse()
- resp.ParseFromString(await ctx.arecv())
- ctx.close()
-
- if security:
- verify_protobuf_msg(resp, skynet_cert)
-
- return resp
-
- yield _rpc_call
-
+    # Context-manager body: yields a connected SessionClient and always
+    # disconnects it on exit.
+    sesh = SessionClient(
+        rpc_address,
+        unique_id,
+        cert_name=cert_name,
+        key_name=key_name
+    )
+    logging.debug(f'opening skynet rpc...')
+    sesh.connect()
+    try:
+        yield sesh
+
+    finally:
+        # An exception raised inside the caller's `with` body is re-raised at
+        # the yield; without the finally the socket would never be closed.
+        sesh.disconnect()
def validate_user_config_request(req: str):
params = req.split(' ')
diff --git a/skynet/frontend/telegram.py b/skynet/frontend/telegram.py
index 3287b3a..65a6fcb 100644
--- a/skynet/frontend/telegram.py
+++ b/skynet/frontend/telegram.py
@@ -6,8 +6,6 @@ import logging
from datetime import datetime
-import pynng
-
from PIL import Image
from trio_asyncio import aio_as_trio
@@ -16,6 +14,7 @@ from telebot.types import (
)
from telebot.async_telebot import AsyncTeleBot
+from ..db import open_database_connection
from ..constants import *
from . import *
@@ -56,228 +55,274 @@ def prepare_metainfo_caption(tguser, meta: dict) -> str:
async def run_skynet_telegram(
+ name: str,
tg_token: str,
- key_name: str = 'telegram-frontend',
- cert_name: str = 'whitelist/telegram-frontend',
- rpc_address: str = DEFAULT_RPC_ADDR
+ key_name: str = 'telegram-frontend.key',
+ cert_name: str = 'whitelist/telegram-frontend.cert',
+ rpc_address: str = DEFAULT_RPC_ADDR,
+ db_host: str = 'localhost:5432',
+ db_user: str = 'skynet',
+ db_pass: str = 'password'
):
logging.basicConfig(level=logging.INFO)
bot = AsyncTeleBot(tg_token)
+    # The bot token is a credential: anyone who can read the logs could take
+    # over the bot, so only record that it was provided, never its value.
+    logging.info('telegram bot token loaded')
- async with open_skynet_rpc(
- 'skynet-telegram-0',
- rpc_address=rpc_address,
- security=True,
- cert_name=cert_name,
- key_name=key_name
- ) as rpc_call:
+ async with open_database_connection(
+ db_user, db_pass, db_host
+ ) as db_call:
+ with open_skynet_rpc(
+ f'skynet-telegram-{name}',
+ rpc_address=rpc_address,
+ cert_name=cert_name,
+ key_name=key_name
+ ) as session:
- async def _rpc_call(
- uid: int,
- method: str,
- params: dict = {}
- ):
- return await rpc_call(
- method, params, uid=f'{PREFIX}+{uid}')
+ @bot.message_handler(commands=['help'])
+ async def send_help(message):
+ splt_msg = message.text.split(' ')
- @bot.message_handler(commands=['help'])
- async def send_help(message):
- splt_msg = message.text.split(' ')
-
- if len(splt_msg) == 1:
- await bot.reply_to(message, HELP_TEXT)
-
- else:
- param = splt_msg[1]
- if param in HELP_TOPICS:
- await bot.reply_to(message, HELP_TOPICS[param])
+ if len(splt_msg) == 1:
+ await bot.reply_to(message, HELP_TEXT)
else:
- await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
+ param = splt_msg[1]
+ if param in HELP_TOPICS:
+ await bot.reply_to(message, HELP_TOPICS[param])
- @bot.message_handler(commands=['cool'])
- async def send_cool_words(message):
- await bot.reply_to(message, '\n'.join(COOL_WORDS))
+ else:
+ await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
- @bot.message_handler(commands=['txt2img'])
- async def send_txt2img(message):
- chat = message.chat
+ @bot.message_handler(commands=['cool'])
+ async def send_cool_words(message):
+ await bot.reply_to(message, '\n'.join(COOL_WORDS))
- prompt = ' '.join(message.text.split(' ')[1:])
+ @bot.message_handler(commands=['txt2img'])
+ async def send_txt2img(message):
+ chat = message.chat
+ reply_id = None
+ if chat.type == 'group' and chat.id == GROUP_ID:
+ reply_id = message.message_id
- if len(prompt) == 0:
- await bot.reply_to(message, 'Empty text prompt ignored.')
- return
+ user_id = f'tg+{message.from_user.id}'
- logging.info(f'mid: {message.id}')
- resp = await _rpc_call(
- message.from_user.id,
- 'txt2img',
- {'prompt': prompt}
- )
- logging.info(f'resp to {message.id} arrived')
+ prompt = ' '.join(message.text.split(' ')[1:])
- resp_txt = ''
- result = MessageToDict(resp.result)
- if 'error' in resp.result:
- resp_txt = resp.result['message']
+ if len(prompt) == 0:
+ await bot.reply_to(message, 'Empty text prompt ignored.')
+ return
- else:
- logging.info(result['id'])
- img_raw = zlib.decompress(bytes.fromhex(result['img']))
- logging.info(f'got image of size: {len(img_raw)}')
- img = Image.open(io.BytesIO(img_raw))
+ logging.info(f'mid: {message.id}')
+ user = await db_call('get_or_create_user', user_id)
+ user_config = {**(await db_call('get_user_config', user))}
+ del user_config['id']
- await bot.send_photo(
- GROUP_ID,
- caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']),
- photo=img,
- reply_markup=build_redo_menu()
+ resp = await session.rpc(
+ 'dgpu_call', {
+ 'method': 'diffuse',
+ 'params': {
+ 'prompt': prompt,
+ **user_config
+ }
+ },
+ timeout=60
)
- return
+ logging.info(f'resp to {message.id} arrived')
- await bot.reply_to(message, resp_txt)
+ resp_txt = ''
+ result = MessageToDict(resp.result)
+ if 'error' in resp.result:
+ resp_txt = resp.result['message']
+ await bot.reply_to(message, resp_txt)
- @bot.message_handler(func=lambda message: True, content_types=['photo'])
- async def send_img2img(message):
- chat = message.chat
+ else:
+ logging.info(result['id'])
+ img_raw = resp.bin
+ logging.info(f'got image of size: {len(img_raw)}')
+ img = Image.open(io.BytesIO(img_raw))
- if not message.caption.startswith('/img2img'):
- return
+ await bot.send_photo(
+ GROUP_ID,
+ caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']),
+ photo=img,
+ reply_to_message_id=reply_id,
+ reply_markup=build_redo_menu()
+ )
+ return
- prompt = ' '.join(message.caption.split(' ')[1:])
- if len(prompt) == 0:
- await bot.reply_to(message, 'Empty text prompt ignored.')
- return
+ @bot.message_handler(func=lambda message: True, content_types=['photo'])
+ async def send_img2img(message):
+ chat = message.chat
+ reply_id = None
+ if chat.type == 'group' and chat.id == GROUP_ID:
+ reply_id = message.message_id
- file_id = message.photo[-1].file_id
- file_path = (await bot.get_file(file_id)).file_path
- file_raw = await bot.download_file(file_path)
- img = zlib.compress(file_raw)
+ user_id = f'tg+{message.from_user.id}'
- logging.info(f'mid: {message.id}')
- resp = await _rpc_call(
- message.from_user.id,
- 'img2img',
- {'prompt': prompt, 'img': img.hex()}
- )
- logging.info(f'resp to {message.id} arrived')
+ if not message.caption.startswith('/img2img'):
+ await bot.reply_to(
+ message,
+ 'For image to image you need to add /img2img to the beggining of your caption'
+ )
+ return
- resp_txt = ''
- result = MessageToDict(resp.result)
- if 'error' in resp.result:
- resp_txt = resp.result['message']
+ prompt = ' '.join(message.caption.split(' ')[1:])
- else:
- logging.info(result['id'])
- img_raw = zlib.decompress(bytes.fromhex(result['img']))
- logging.info(f'got image of size: {len(img_raw)}')
- img = Image.open(io.BytesIO(img_raw))
+ if len(prompt) == 0:
+ await bot.reply_to(message, 'Empty text prompt ignored.')
+ return
- await bot.send_media_group(
- GROUP_ID,
- media=[
- InputMediaPhoto(file_id),
- InputMediaPhoto(
- img,
- caption=prepare_metainfo_caption(message.from_user, result['meta']['meta'])
- )
- ]
+ file_id = message.photo[-1].file_id
+ file_path = (await bot.get_file(file_id)).file_path
+ file_raw = await bot.download_file(file_path)
+
+ logging.info(f'mid: {message.id}')
+
+ user = await db_call('get_or_create_user', user_id)
+ user_config = {**(await db_call('get_user_config', user))}
+ del user_config['id']
+
+ resp = await session.rpc(
+ 'dgpu_call', {
+ 'method': 'diffuse',
+ 'params': {
+ 'prompt': prompt,
+ **user_config
+ }
+ },
+ binext=file_raw,
+ timeout=60
)
- return
+ logging.info(f'resp to {message.id} arrived')
- await bot.reply_to(message, resp_txt)
+ resp_txt = ''
+ result = MessageToDict(resp.result)
+ if 'error' in resp.result:
+ resp_txt = resp.result['message']
+ await bot.reply_to(message, resp_txt)
- @bot.message_handler(commands=['img2img'])
- async def redo_txt2img(message):
- await bot.reply_to(
- message,
- 'seems you tried to do an img2img command without sending image'
- )
+ else:
+ logging.info(result['id'])
+ img_raw = resp.bin
+ logging.info(f'got image of size: {len(img_raw)}')
+ img = Image.open(io.BytesIO(img_raw))
- async def _redo(message):
- resp = await _rpc_call(message.from_user.id, 'redo')
+ await bot.send_media_group(
+ GROUP_ID,
+ media=[
+ InputMediaPhoto(file_id),
+ InputMediaPhoto(
+ img,
+ caption=prepare_metainfo_caption(message.from_user, result['meta']['meta'])
+ )
+ ],
+ reply_to_message_id=reply_id
+ )
+ return
- resp_txt = ''
- result = MessageToDict(resp.result)
- if 'error' in resp.result:
- resp_txt = resp.result['message']
- else:
- logging.info(result['id'])
- img_raw = zlib.decompress(bytes.fromhex(result['img']))
- logging.info(f'got image of size: {len(img_raw)}')
- img = Image.open(io.BytesIO(img_raw))
-
- await bot.send_photo(
- GROUP_ID,
- caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']),
- photo=img,
- reply_markup=build_redo_menu()
+ @bot.message_handler(commands=['img2img'])
+ async def img2img_missing_image(message):
+ await bot.reply_to(
+ message,
+ 'seems you tried to do an img2img command without sending image'
)
- return
- await bot.reply_to(message, resp_txt)
+ @bot.message_handler(commands=['redo'])
+ async def redo(message):
+ chat = message.chat
+ reply_id = None
+ if chat.type == 'group' and chat.id == GROUP_ID:
+ reply_id = message.message_id
- @bot.message_handler(commands=['redo'])
- async def redo_txt2img(message):
- await _redo(message)
+                # `user` was never bound in this handler (NameError on every
+                # /redo); resolve it from the sender exactly like the txt2img
+                # and img2img handlers do.
+                user = await db_call(
+                    'get_or_create_user', f'tg+{message.from_user.id}')
+                user_config = {**(await db_call('get_user_config', user))}
+                del user_config['id']
+                prompt = await db_call('get_last_prompt_of', user)
- @bot.message_handler(commands=['config'])
- async def set_config(message):
- rpc_params = {}
- try:
- attr, val, reply_txt = validate_user_config_request(
- message.text)
+ resp = await session.rpc(
+ 'dgpu_call', {
+ 'method': 'diffuse',
+ 'params': {
+ 'prompt': prompt,
+ **user_config
+ }
+ },
+ timeout=60
+ )
+ logging.info(f'resp to {message.id} arrived')
- resp = await _rpc_call(
- message.from_user.id,
- 'config', {'attr': attr, 'val': val})
+ resp_txt = ''
+ result = MessageToDict(resp.result)
+ if 'error' in resp.result:
+ resp_txt = resp.result['message']
+ await bot.reply_to(message, resp_txt)
- except BaseException as e:
- reply_txt = str(e)
+ else:
+ logging.info(result['id'])
+ img_raw = resp.bin
+ logging.info(f'got image of size: {len(img_raw)}')
+ img = Image.open(io.BytesIO(img_raw))
- finally:
- await bot.reply_to(message, reply_txt)
+ await bot.send_photo(
+ GROUP_ID,
+ caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']),
+ photo=img,
+ reply_to_message_id=reply_id
+ )
+ return
- @bot.message_handler(commands=['stats'])
- async def user_stats(message):
- resp = await _rpc_call(
- message.from_user.id,
- 'stats',
- {}
- )
- stats = resp.result
+ @bot.message_handler(commands=['config'])
+ async def set_config(message):
+ rpc_params = {}
+ try:
+ attr, val, reply_txt = validate_user_config_request(
+ message.text)
- stats_str = f'generated: {stats["generated"]}\n'
- stats_str += f'joined: {stats["joined"]}\n'
- stats_str += f'role: {stats["role"]}\n'
+                    logging.info(f'user config update: {attr} to {val}')
+                    # Neither `user` nor `req` exists in this handler; look the
+                    # user up from the sender and store the values that were
+                    # just validated above.
+                    user = await db_call(
+                        'get_or_create_user', f'tg+{message.from_user.id}')
+                    await db_call('update_user_config', user, attr, val)
+                    logging.info('done')
- await bot.reply_to(
- message, stats_str)
+ except BaseException as e:
+ reply_txt = str(e)
- @bot.message_handler(commands=['donate'])
- async def donation_info(message):
- await bot.reply_to(
- message, DONATION_INFO)
+ finally:
+ await bot.reply_to(message, reply_txt)
- @bot.message_handler(commands=['say'])
- async def say(message):
- chat = message.chat
- user = message.from_user
+ @bot.message_handler(commands=['stats'])
+ async def user_stats(message):
- if (chat.type == 'group') or (user.id != 383385940):
- return
+                # `user` was undefined here; resolve it from the sender first.
+                user = await db_call(
+                    'get_or_create_user', f'tg+{message.from_user.id}')
+                generated, joined, role = await db_call('get_user_stats', user)
- await bot.send_message(GROUP_ID, message.text[4:])
+ stats_str = f'generated: {generated}\n'
+ stats_str += f'joined: {joined}\n'
+ stats_str += f'role: {role}\n'
+
+ await bot.reply_to(
+ message, stats_str)
+
+ @bot.message_handler(commands=['donate'])
+ async def donation_info(message):
+ await bot.reply_to(
+ message, DONATION_INFO)
+
+ @bot.message_handler(commands=['say'])
+ async def say(message):
+ chat = message.chat
+ user = message.from_user
+
+ if (chat.type == 'group') or (user.id != 383385940):
+ return
+
+ await bot.send_message(GROUP_ID, message.text[4:])
- @bot.message_handler(func=lambda message: True)
- async def echo_message(message):
- if message.text[0] == '/':
- await bot.reply_to(message, UNKNOWN_CMD_TEXT)
+ @bot.message_handler(func=lambda message: True)
+ async def echo_message(message):
+ if message.text[0] == '/':
+ await bot.reply_to(message, UNKNOWN_CMD_TEXT)
@bot.callback_query_handler(func=lambda call: True)
async def callback_query(call):
@@ -289,4 +334,4 @@ async def run_skynet_telegram(
await _redo(call)
- await aio_as_trio(bot.infinity_polling())
+ await aio_as_trio(bot.infinity_polling)()
diff --git a/skynet/network.py b/skynet/network.py
new file mode 100644
index 0000000..95fb60f
--- /dev/null
+++ b/skynet/network.py
@@ -0,0 +1,341 @@
+#!/usr/bin/python
+
+import zlib
+import socket
+import logging
+
+from typing import Callable, Awaitable, Optional
+from pathlib import Path
+from contextlib import asynccontextmanager as acm
+
+import trio
+import pynng
+
+from pynng import TLSConfig, Context
+from cryptography import x509
+from cryptography.exceptions import InvalidSignature
+from cryptography.hazmat.primitives import serialization
+from google.protobuf.message import DecodeError
+
+from .protobuf import *
+from .constants import *
+
+
+def get_random_port():
+    """Ask the OS for a currently unused TCP port and return its number.
+
+    The probing socket is closed before returning, so the port is free for
+    the caller to bind. Nothing reserves it in the meantime: another process
+    may take it first, so the caller's own bind can still fail.
+    """
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        # port 0 lets the kernel pick any free port
+        s.bind(('', 0))
+        return s.getsockname()[1]
+
+
+def load_certs(
+    certs_dir: str,
+    cert_name: str,
+    key_name: str
+):
+    """Load the server-side TLS material found under *certs_dir*.
+
+    Returns a ``(SessionTLSConfig, whitelist)`` tuple. The config is built in
+    ``TLSConfig.MODE_SERVER`` from the PEM files *cert_name* and *key_name*.
+    The whitelist maps the file stem of every ``whitelist/*.cert`` file, plus
+    ``brain.cert``, to its parsed x509 certificate.
+    """
+    certs_dir = Path(certs_dir).resolve()
+    tls_key_data = (certs_dir / key_name).read_bytes()
+    tls_cert_data = (certs_dir / cert_name).read_bytes()
+
+    # The key and certificate are parsed (and so validated) by
+    # SessionTLSConfig; parsing them here as well only produced unused locals.
+    tls_whitelist = {
+        cert_path.stem: x509.load_pem_x509_certificate(cert_path.read_bytes())
+        for cert_path in (
+            *(certs_dir / 'whitelist').glob('*.cert'),
+            certs_dir / 'brain.cert'
+        )
+    }
+
+    return (
+        SessionTLSConfig(
+            TLSConfig.MODE_SERVER,
+            own_key_string=tls_key_data,
+            own_cert_string=tls_cert_data
+        ),
+
+        tls_whitelist
+    )
+
+
+def load_certs_client(
+    certs_dir: str,
+    cert_name: str,
+    key_name: str,
+    ca_name: Optional[str] = None
+):
+    """Load the client-side TLS material found under *certs_dir*.
+
+    *ca_name* falls back to ``brain.cert`` when it is empty. Returns a
+    ``(SessionTLSConfig, whitelist)`` tuple. The config is built in
+    ``TLSConfig.MODE_CLIENT`` from the CA, certificate and key PEM files.
+    The whitelist maps the file stem of every ``whitelist/*.cert`` file, plus
+    ``brain.cert``, to its parsed x509 certificate.
+    """
+    base_dir = Path(certs_dir).resolve()
+    ca_file = ca_name if ca_name else 'brain.cert'
+
+    # raw PEM bytes, read in the order: CA, key, certificate
+    ca_pem = (base_dir / ca_file).read_bytes()
+    key_pem = (base_dir / key_name).read_bytes()
+    cert_pem = (base_dir / cert_name).read_bytes()
+
+    trusted_paths = [*(base_dir / 'whitelist').glob('*.cert'), base_dir / 'brain.cert']
+    whitelist = {}
+    for path in trusted_paths:
+        whitelist[path.stem] = x509.load_pem_x509_certificate(path.read_bytes())
+
+    client_config = SessionTLSConfig(
+        TLSConfig.MODE_CLIENT,
+        own_key_string=key_pem,
+        own_cert_string=cert_pem,
+        ca_string=ca_pem
+    )
+    return client_config, whitelist
+
+
+class SessionError(BaseException):
+ """Raised by SessionClient.rpc when the client is not connected."""
+ ...
+
+
+class SessionTLSConfig(TLSConfig):
+    """pynng ``TLSConfig`` that also keeps the parsed certificates and key.
+
+    Takes the same arguments as ``TLSConfig`` and forwards them unchanged.
+    In addition it exposes three attributes, each ``None`` when the matching
+    argument was not supplied:
+
+    - ``ca_cert``: x509 certificate parsed from *ca_string*
+    - ``cert``: x509 certificate parsed from *own_cert_string*
+    - ``key``: private key parsed from *own_key_string*
+    """
+
+    def __init__(
+        self,
+        mode,
+        server_name=None,
+        ca_string=None,
+        own_key_string=None,
+        own_cert_string=None,
+        auth_mode=None,
+        ca_files=None,
+        cert_key_file=None,
+        passwd=None
+    ):
+        super().__init__(
+            mode,
+            server_name=server_name,
+            ca_string=ca_string,
+            own_key_string=own_key_string,
+            own_cert_string=own_cert_string,
+            auth_mode=auth_mode,
+            ca_files=ca_files,
+            cert_key_file=cert_key_file,
+            passwd=passwd
+        )
+
+        # Always define the attributes so that readers get None instead of
+        # an AttributeError when a piece of material was not provided.
+        self.ca_cert = None
+        self.cert = None
+        self.key = None
+
+        if ca_string:
+            self.ca_cert = x509.load_pem_x509_certificate(ca_string)
+
+        if own_cert_string:
+            self.cert = x509.load_pem_x509_certificate(own_cert_string)
+
+        if own_key_string:
+            self.key = serialization.load_pem_private_key(
+                own_key_string,
+                # cryptography only accepts a bytes password
+                password=passwd.encode() if isinstance(passwd, str) else passwd
+            )
+
+
+class SessionServer:
+    """Rep0 RPC server that validates requests and hands them to a handler.
+
+    Every message on the wire is a zlib-compressed serialized protobuf. When
+    both *cert_name* and *key_name* are given, the server listens over TLS
+    (``tls+`` is prepended to *addr*) and signs every response. It then only
+    dispatches requests whose ``auth.cert`` is in the whitelist loaded from
+    *cert_dir* and whose signature verifies.
+
+    *msg_handler* is awaited with ``(request, context)`` and must return the
+    ``SkynetRPCResponse`` to send back.
+    """
+
+    def __init__(
+        self,
+        addr: str,
+        msg_handler: Callable[
+            [SkynetRPCRequest, Context], Awaitable[SkynetRPCResponse]
+        ],
+        cert_name: Optional[str] = None,
+        key_name: Optional[str] = None,
+        cert_dir: str = DEFAULT_CERTS_DIR,
+        recv_max_size = 0
+    ):
+        self.addr = addr
+        self.msg_handler = msg_handler
+
+        self.cert_name = cert_name
+        self.tls_config = None
+        self.tls_whitelist = None
+        if cert_name and key_name:
+            self.tls_config, self.tls_whitelist = load_certs(
+                cert_dir, cert_name, key_name)
+
+            self.addr = 'tls+' + self.addr
+
+        self.recv_max_size = recv_max_size
+
+    async def _handle_msg(self, req: SkynetRPCRequest, ctx: Context):
+        """Run the handler for *req* and send its reply on *ctx*.
+
+        The reply is signed when TLS is configured. The context is closed
+        whether or not the handler or the send succeeds.
+        """
+        try:
+            resp = await self.msg_handler(req, ctx)
+
+            if self.tls_config:
+                resp.auth.cert = 'skynet'
+                resp.auth.sig = sign_protobuf_msg(
+                    resp, self.tls_config.key)
+
+            raw_msg = zlib.compress(resp.SerializeToString())
+
+            await ctx.asend(raw_msg)
+
+        finally:
+            ctx.close()
+
+    async def _listener(self, sock):
+        """Receive requests on *sock* forever, one task per accepted request.
+
+        Messages that fail to decompress, parse or authenticate are logged
+        and dropped, and the context they arrived on is closed.
+        """
+        async with trio.open_nursery() as n:
+            while True:
+                ctx = sock.new_context()
+
+                raw_msg = await ctx.arecv()
+                raw_size = len(raw_msg)
+                logging.debug(f'rpc server new msg {raw_size} bytes')
+
+                try:
+                    msg = zlib.decompress(raw_msg)
+
+                except zlib.error:
+                    logging.warning(f'Zlib decompress error, dropping msg of size {raw_size}')
+                    ctx.close()
+                    continue
+
+                msg_size = len(msg)
+                logging.debug(f'msg after decompress {msg_size} bytes, +{msg_size - raw_size} bytes')
+
+                req = SkynetRPCRequest()
+                try:
+                    req.ParseFromString(msg)
+
+                except DecodeError:
+                    logging.warning(f'Dropping malfomed msg of size {msg_size}')
+                    ctx.close()
+                    continue
+
+                logging.debug(f'msg method: {req.method}')
+
+                if self.tls_config:
+                    if req.auth.cert not in self.tls_whitelist:
+                        logging.warning(
+                            f'{req.auth.cert} not in tls whitelist')
+                        ctx.close()
+                        continue
+
+                    try:
+                        verify_protobuf_msg(req, self.tls_whitelist[req.auth.cert])
+
+                    except (ValueError, InvalidSignature):
+                        # ValueError: the signature is not valid hex.
+                        # InvalidSignature: it does not match the message.
+                        logging.warning(
+                            f'{req.auth.cert} sent an unauthenticated msg')
+                        ctx.close()
+                        continue
+
+                # _handle_msg owns ctx from here on and closes it
+                n.start_soon(self._handle_msg, req, ctx)
+
+    @acm
+    async def open(self):
+        """Listen on ``self.addr`` and serve until the ``async with`` exits.
+
+        Yields the server itself; on exit the listener task is cancelled and
+        the socket is closed.
+        """
+        with pynng.Rep0(
+            recv_max_size=self.recv_max_size
+        ) as sock:
+
+            if self.tls_config:
+                sock.tls_config = self.tls_config
+
+            sock.listen(self.addr)
+
+            logging.debug(f'server socket listening at {self.addr}')
+
+            async with trio.open_nursery() as n:
+                n.start_soon(self._listener, sock)
+
+                try:
+                    yield self
+
+                finally:
+                    n.cancel_scope.cancel()
+
+        logging.debug('server socket is off.')
+
+
+class SessionClient:
+    """Req0 RPC client that talks to a ``SessionServer``.
+
+    Requests and responses are zlib-compressed serialized protobufs. When
+    both *cert_name* and *key_name* are given the client dials over TLS
+    (``tls+`` is prepended to *connect_addr* unless it already starts with
+    ``tls``), signs every request under the certificate file stem, and
+    verifies every response against the CA certificate.
+    """
+
+    def __init__(
+        self,
+        connect_addr: str,
+        uid: str,
+        cert_name: Optional[str] = None,
+        key_name: Optional[str] = None,
+        ca_name: Optional[str] = None,
+        cert_dir: str = DEFAULT_CERTS_DIR,
+        recv_max_size = 0
+    ):
+        self.uid = uid
+        self.connect_addr = connect_addr
+
+        self.cert_name = None
+        self.tls_config = None
+        self.tls_whitelist = None
+        self.tls_cert = None
+        self.tls_key = None
+        if cert_name and key_name:
+            self.cert_name = Path(cert_name).stem
+            self.tls_config, self.tls_whitelist = load_certs_client(
+                cert_dir, cert_name, key_name, ca_name=ca_name)
+
+            if not self.connect_addr.startswith('tls'):
+                self.connect_addr = 'tls+' + self.connect_addr
+
+        self.recv_max_size = recv_max_size
+
+        self._connected = False
+        self._sock = None
+
+    def connect(self):
+        """Open the socket and dial ``connect_addr``, blocking until connected.
+
+        If dialing fails the socket is closed again and the error re-raised.
+        """
+        sock = pynng.Req0(
+            # honour the configured limit instead of a hard-coded 0
+            recv_max_size=self.recv_max_size,
+            name=self.uid
+        )
+
+        try:
+            if self.tls_config:
+                sock.tls_config = self.tls_config
+
+            logging.debug(f'client is dialing {self.connect_addr}...')
+            sock.dial(self.connect_addr, block=True)
+
+        except BaseException:
+            sock.close()
+            raise
+
+        self._sock = sock
+        self._connected = True
+        logging.debug(f'client is connected to {self.connect_addr}')
+
+    def disconnect(self):
+        """Close the socket; a no-op if the client is not connected."""
+        if self._sock is not None:
+            self._sock.close()
+            self._sock = None
+
+        self._connected = False
+        logging.debug(f'client disconnected.')
+
+    async def rpc(
+        self,
+        method: str,
+        params: Optional[dict] = None,
+        binext: Optional[bytes] = None,
+        timeout: float = 2.
+    ):
+        """Send one request and return the decoded ``SkynetRPCResponse``.
+
+        *params* fills the request's params struct (empty when omitted) and
+        *binext* is attached as the binary payload when non-empty.
+
+        Raises ``SessionError`` if the client is not connected and
+        ``trio.TooSlowError`` if no reply arrives within *timeout* seconds.
+        An error from response verification propagates to the caller.
+        """
+        if not self._connected:
+            raise SessionError('tried to use rpc without connecting')
+
+        req = SkynetRPCRequest()
+        req.uid = self.uid
+        req.method = method
+        req.params.update(params if params is not None else {})
+        if binext:
+            logging.debug('added binary extension')
+            req.bin = binext
+
+        if self.tls_config:
+            req.auth.cert = self.cert_name
+            req.auth.sig = sign_protobuf_msg(req, self.tls_config.key)
+
+        raw_req = zlib.compress(req.SerializeToString())
+
+        ctx = self._sock.new_context()
+        try:
+            with trio.fail_after(timeout):
+                logging.debug(f'rpc client sending new msg {method} of size {len(raw_req)}')
+                await ctx.asend(raw_req)
+                logging.debug('sent, awaiting response...')
+                raw_resp = await ctx.arecv()
+
+        finally:
+            # also runs on timeout or error, so the context is never leaked
+            ctx.close()
+
+        logging.debug(f'rpc client got response of size {len(raw_resp)}')
+
+        resp = SkynetRPCResponse()
+        resp.ParseFromString(zlib.decompress(raw_resp))
+
+        if self.tls_config:
+            verify_protobuf_msg(resp, self.tls_config.ca_cert)
+
+        return resp
diff --git a/skynet/protobuf/__init__.py b/skynet/protobuf/__init__.py
index b985940..acafec8 100644
--- a/skynet/protobuf/__init__.py
+++ b/skynet/protobuf/__init__.py
@@ -1,29 +1,4 @@
#!/usr/bin/python
-from typing import Optional
-from dataclasses import dataclass, asdict
-
-from google.protobuf.json_format import MessageToDict
-
from .auth import *
from .skynet_pb2 import *
-
-
-class Struct:
-
- def to_dict(self):
- return asdict(self)
-
-
-@dataclass
-class DiffusionParameters(Struct):
- algo: str
- prompt: str
- step: int
- width: int
- height: int
- guidance: float
- strength: float
- seed: Optional[int]
- image: bool # if true indicates a bytestream is next msg
- upscaler: Optional[str]
diff --git a/skynet/protobuf/auth.py b/skynet/protobuf/auth.py
index e2904cb..876683d 100644
--- a/skynet/protobuf/auth.py
+++ b/skynet/protobuf/auth.py
@@ -7,7 +7,8 @@ from hashlib import sha256
from collections import OrderedDict
from google.protobuf.json_format import MessageToDict
-from OpenSSL.crypto import PKey, X509, verify, sign
+from cryptography.hazmat.primitives import serialization, hashes
+from cryptography.hazmat.primitives.asymmetric import padding
from .skynet_pb2 import *
@@ -46,20 +47,23 @@ def serialize_msg_deterministic(msg):
if field_descriptor.message_type.name == 'Struct':
hash_dict(MessageToDict(getattr(msg, field_name)))
- deterministic_msg = shasum.hexdigest()
+ deterministic_msg = shasum.digest()
return deterministic_msg
-def sign_protobuf_msg(msg, key: PKey):
- return sign(
- key, serialize_msg_deterministic(msg), 'sha256').hex()
+def sign_protobuf_msg(msg, key):
+ '''Sign the deterministic digest of ``msg`` and return the signature
+ as a hex string.
+
+ ``key`` must provide ``sign(data, padding, algorithm)``; presumably an
+ RSA private key object from ``cryptography`` -- TODO confirm against
+ the certificate loading code.
+ '''
+ return key.sign(
+ serialize_msg_deterministic(msg),
+ padding.PKCS1v15(),
+ hashes.SHA256()
+ ).hex()
-def verify_protobuf_msg(msg, cert: X509):
- return verify(
- cert,
+def verify_protobuf_msg(msg, cert):
+ '''Check the hex signature stored in ``msg.auth.sig`` against the
+ deterministic digest of ``msg`` using the public key of ``cert``.
+
+ NOTE(review): the result of ``verify`` is returned as is; the
+ ``cryptography`` docs describe it as raising ``InvalidSignature`` on
+ mismatch rather than returning a boolean, so callers presumably must
+ not test the return value -- confirm.
+ '''
+ return cert.public_key().verify(
bytes.fromhex(msg.auth.sig),
serialize_msg_deterministic(msg),
- 'sha256'
+ padding.PKCS1v15(),
+ hashes.SHA256()
)
diff --git a/skynet/protobuf/skynet.proto b/skynet/protobuf/skynet.proto
index 6e66274..0bdccad 100644
--- a/skynet/protobuf/skynet.proto
+++ b/skynet/protobuf/skynet.proto
@@ -13,18 +13,12 @@ message SkynetRPCRequest {
string uid = 1;
string method = 2;
google.protobuf.Struct params = 3;
- optional Auth auth = 4;
+ optional bytes bin = 4;
+ optional Auth auth = 5;
}
message SkynetRPCResponse {
google.protobuf.Struct result = 1;
- optional Auth auth = 2;
-}
-
-message DGPUBusMessage {
- string rid = 1;
- string nid = 2;
- string method = 3;
- google.protobuf.Struct params = 4;
- optional Auth auth = 5;
+ optional bytes bin = 2;
+ optional Auth auth = 3;
}
diff --git a/skynet/protobuf/skynet_pb2.py b/skynet/protobuf/skynet_pb2.py
index dd7db33..84b0527 100644
--- a/skynet/protobuf/skynet_pb2.py
+++ b/skynet/protobuf/skynet_pb2.py
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cskynet.proto\x12\x06skynet\x1a\x1cgoogle/protobuf/struct.proto\"!\n\x04\x41uth\x12\x0c\n\x04\x63\x65rt\x18\x01 \x01(\t\x12\x0b\n\x03sig\x18\x02 \x01(\t\"\x82\x01\n\x10SkynetRPCRequest\x12\x0b\n\x03uid\x18\x01 \x01(\t\x12\x0e\n\x06method\x18\x02 \x01(\t\x12\'\n\x06params\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x04 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"f\n\x11SkynetRPCResponse\x12\'\n\x06result\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x02 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"\x8d\x01\n\x0e\x44GPUBusMessage\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x0b\n\x03nid\x18\x02 \x01(\t\x12\x0e\n\x06method\x18\x03 \x01(\t\x12\'\n\x06params\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x05 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_authb\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cskynet.proto\x12\x06skynet\x1a\x1cgoogle/protobuf/struct.proto\"!\n\x04\x41uth\x12\x0c\n\x04\x63\x65rt\x18\x01 \x01(\t\x12\x0b\n\x03sig\x18\x02 \x01(\t\"\x9c\x01\n\x10SkynetRPCRequest\x12\x0b\n\x03uid\x18\x01 \x01(\t\x12\x0e\n\x06method\x18\x02 \x01(\t\x12\'\n\x06params\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x10\n\x03\x62in\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x12\x1f\n\x04\x61uth\x18\x05 \x01(\x0b\x32\x0c.skynet.AuthH\x01\x88\x01\x01\x42\x06\n\x04_binB\x07\n\x05_auth\"\x80\x01\n\x11SkynetRPCResponse\x12\'\n\x06result\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x10\n\x03\x62in\x18\x02 \x01(\x0cH\x00\x88\x01\x01\x12\x1f\n\x04\x61uth\x18\x03 \x01(\x0b\x32\x0c.skynet.AuthH\x01\x88\x01\x01\x42\x06\n\x04_binB\x07\n\x05_authb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'skynet_pb2', globals())
@@ -24,9 +24,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_AUTH._serialized_start=54
_AUTH._serialized_end=87
_SKYNETRPCREQUEST._serialized_start=90
- _SKYNETRPCREQUEST._serialized_end=220
- _SKYNETRPCRESPONSE._serialized_start=222
- _SKYNETRPCRESPONSE._serialized_end=324
- _DGPUBUSMESSAGE._serialized_start=327
- _DGPUBUSMESSAGE._serialized_end=468
+ _SKYNETRPCREQUEST._serialized_end=246
+ _SKYNETRPCRESPONSE._serialized_start=249
+ _SKYNETRPCRESPONSE._serialized_end=377
# @@protoc_insertion_point(module_scope)
diff --git a/skynet/utils.py b/skynet/utils.py
index ba1ce2d..637078b 100644
--- a/skynet/utils.py
+++ b/skynet/utils.py
@@ -1,5 +1,6 @@
#!/usr/bin/python
+import time
import random
from typing import Optional
@@ -21,6 +22,10 @@ from huggingface_hub import login
from .constants import ALGOS
+def time_ms():
+ '''Return ``time.time()`` converted to whole milliseconds (truncated).'''
+ return int(time.time() * 1000)
+
+
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
# return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
return Image.fromarray(img)
@@ -164,3 +169,13 @@ def upscale(
image.save(output)
+
+
+def download_all_models(hf_token: str):
+ '''Log in to the Hugging Face hub with ``hf_token`` and call
+ ``pipeline_for`` once for every entry of ``ALGOS``.
+
+ Presumably building each pipeline downloads its weights as a side
+ effect -- TODO confirm in ``pipeline_for``.
+ '''
+ # NOTE(review): assert statements are stripped under ``python -O``,
+ # so this CUDA precondition is not enforced in optimized runs.
+ assert torch.cuda.is_available()
+
+ login(token=hf_token)
+ for model in ALGOS:
+ print(f'DOWNLOADING {model.upper()}')
+ pipeline_for(model)
+
diff --git a/tests/conftest.py b/tests/conftest.py
index 64a369f..0b4c335 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,89 +3,30 @@
import os
import json
import time
-import random
-import string
import logging
-from functools import partial
from pathlib import Path
+from functools import partial
-import trio
import pytest
-import psycopg2
-import trio_asyncio
from docker.types import Mount, DeviceRequest
-from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
-from skynet.constants import *
+from skynet.db import open_new_database
from skynet.brain import run_skynet
+from skynet.network import get_random_port
+from skynet.constants import *
@pytest.fixture(scope='session')
def postgres_db(dockerctl):
- rpassword = ''.join(
- random.choice(string.ascii_lowercase)
- for i in range(12))
- password = ''.join(
- random.choice(string.ascii_lowercase)
- for i in range(12))
-
- with dockerctl.run(
- 'postgres',
- name='skynet-test-postgres',
- ports={'5432/tcp': None},
- environment={
- 'POSTGRES_PASSWORD': rpassword
- }
- ) as containers:
- container = containers[0]
- # ip = container.attrs['NetworkSettings']['IPAddress']
- port = container.ports['5432/tcp'][0]['HostPort']
- host = f'localhost:{port}'
-
- for log in container.logs(stream=True):
- log = log.decode().rstrip()
- logging.info(log)
- if ('database system is ready to accept connections' in log or
- 'database system is shut down' in log):
- break
-
- # why print the system is ready to accept connections when its not
- # postgres? wtf
- time.sleep(1)
- logging.info('creating skynet db...')
-
- conn = psycopg2.connect(
- user='postgres',
- password=rpassword,
- host='localhost',
- port=port
- )
- logging.info('connected...')
- conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
- with conn.cursor() as cursor:
- cursor.execute(
- f'CREATE USER {DB_USER} WITH PASSWORD \'{password}\'')
- cursor.execute(
- f'CREATE DATABASE {DB_NAME}')
- cursor.execute(
- f'GRANT ALL PRIVILEGES ON DATABASE {DB_NAME} TO {DB_USER}')
-
- conn.close()
-
- logging.info('done.')
- yield container, password, host
+ '''Session-scoped fixture yielding whatever ``open_new_database()``
+ produces, for the lifetime of the test session.
+
+ NOTE(review): ``dockerctl`` is no longer used by this body;
+ presumably kept so the docker fixture is set up first -- confirm.
+ '''
+ with open_new_database() as db_params:
+ yield db_params
@pytest.fixture
-async def skynet_running(postgres_db):
- db_container, db_pass, db_host = postgres_db
-
- async with run_skynet(
- db_pass=db_pass,
- db_host=db_host
- ):
+async def skynet_running():
+ '''Keep ``run_skynet()`` (called with its defaults) open for the
+ duration of one test.
+
+ NOTE(review): this no longer depends on ``postgres_db``; presumably
+ ``run_skynet`` now provisions its own database -- confirm.
+ '''
+ async with run_skynet():
yield
@@ -99,11 +40,13 @@ def dgpu_workers(request, dockerctl, skynet_running):
cmds = []
for i in range(num_containers):
+ dgpu_addr = f'tcp://127.0.0.1:{get_random_port()}'
cmd = f'''
pip install -e . && \
skynet run dgpu \
--algos=\'{json.dumps(initial_algos)}\' \
- --uid=dgpu-{i}
+ --uid=dgpu-{i} \
+ --dgpu={dgpu_addr}
'''
cmds.append(['bash', '-c', cmd])
@@ -114,16 +57,15 @@ def dgpu_workers(request, dockerctl, skynet_running):
name='skynet-test-runtime-cuda',
commands=cmds,
environment={
- 'HF_TOKEN': os.environ['HF_TOKEN'],
'HF_HOME': '/skynet/hf_home'
},
network='host',
mounts=mounts,
device_requests=devices,
- num=num_containers
+ num=num_containers,
) as containers:
yield containers
- #for i, container in enumerate(containers):
- # logging.info(f'container {i} logs:')
- # logging.info(container.logs().decode())
+ for i, container in enumerate(containers):
+ logging.info(f'container {i} logs:')
+ logging.info(container.logs().decode())
diff --git a/tests/test_dgpu.py b/tests/test_dgpu.py
index 4ce93bf..c187af0 100644
--- a/tests/test_dgpu.py
+++ b/tests/test_dgpu.py
@@ -12,29 +12,26 @@ from functools import partial
import trio
import pytest
-import trio_asyncio
from PIL import Image
from google.protobuf.json_format import MessageToDict
from skynet.brain import SkynetDGPUComputeError
-from skynet.constants import *
+from skynet.network import get_random_port, SessionServer
+from skynet.protobuf import SkynetRPCResponse
from skynet.frontend import open_skynet_rpc
+from skynet.constants import *
-async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
+async def wait_for_dgpus(session, amount: int, timeout: float = 30.0):
+ '''Poll the ``dgpu_workers`` rpc once per second until it reports at
+ least ``amount`` workers.
+
+ The whole wait runs under ``trio.fail_after(timeout)``, so running
+ out of time surfaces as the exception raised by that scope instead
+ of a failed assert.
+ '''
gpu_ready = False
- start_time = time.time()
- current_time = time.time()
- while not gpu_ready and (current_time - start_time) < timeout:
- res = await rpc('dgpu_workers')
- if res.result['ok'] >= amount:
- break
+ with trio.fail_after(timeout):
+ # gpu_ready is never set to True: the loop only ends through the
+ # break below or through the timeout
+ while not gpu_ready:
+ res = await session.rpc('dgpu_workers')
+ if res.result['ok'] >= amount:
+ break
- await trio.sleep(1)
- current_time = time.time()
-
- assert (current_time - start_time) < timeout
+ await trio.sleep(1)
_images = set()
@@ -48,34 +45,33 @@ async def check_request_img(
):
global _images
- async with open_skynet_rpc(
+ with open_skynet_rpc(
uid,
- security=True,
- cert_name='whitelist/testing',
- key_name='testing'
- ) as rpc_call:
- res = await rpc_call(
- 'txt2img', {
- 'prompt': 'red old tractor in a sunny wheat field',
- 'step': 28,
- 'width': width, 'height': height,
- 'guidance': 7.5,
- 'seed': None,
- 'algo': list(ALGOS.keys())[i],
- 'upscaler': upscaler
- })
+ cert_name='whitelist/testing.cert',
+ key_name='testing.key'
+ ) as session:
+ res = await session.rpc(
+ 'dgpu_call', {
+ 'method': 'diffuse',
+ 'params': {
+ 'prompt': 'red old tractor in a sunny wheat field',
+ 'step': 28,
+ 'width': width, 'height': height,
+ 'guidance': 7.5,
+ 'seed': None,
+ 'algo': list(ALGOS.keys())[i],
+ 'upscaler': upscaler
+ }
+ },
+ timeout=60
+ )
if 'error' in res.result:
raise SkynetDGPUComputeError(MessageToDict(res.result))
- if upscaler == 'x4':
- width *= 4
- height *= 4
-
- img_raw = zlib.decompress(bytes.fromhex(res.result['img']))
+ img_raw = res.bin
img_sha = sha256(img_raw).hexdigest()
- img = Image.frombytes(
- 'RGB', (width, height), img_raw)
+ img = Image.open(io.BytesIO(img_raw))
if expect_unique and img_sha in _images:
raise ValueError('Duplicated image sha: {img_sha}')
@@ -96,13 +92,12 @@ async def test_dgpu_worker_compute_error(dgpu_workers):
then generate a smaller image to show gpu worker recovery
'''
- async with open_skynet_rpc(
+ with open_skynet_rpc(
'test-ctx',
- security=True,
- cert_name='whitelist/testing',
- key_name='testing'
- ) as test_rpc:
- await wait_for_dgpus(test_rpc, 1)
+ cert_name='whitelist/testing.cert',
+ key_name='testing.key'
+ ) as session:
+ await wait_for_dgpus(session, 1)
with pytest.raises(SkynetDGPUComputeError) as e:
await check_request_img(0, width=4096, height=4096)
@@ -112,20 +107,35 @@ async def test_dgpu_worker_compute_error(dgpu_workers):
await check_request_img(0)
+@pytest.mark.parametrize(
+ 'dgpu_workers', [(1, ['midj'])], indirect=True)
+async def test_dgpu_worker(dgpu_workers):
+ '''Generate one image in a single dgpu worker
+ '''
+
+ with open_skynet_rpc(
+ 'test-ctx',
+ cert_name='whitelist/testing.cert',
+ key_name='testing.key'
+ ) as session:
+ # block until skynet reports at least one registered worker
+ await wait_for_dgpus(session, 1)
+
+ # check_request_img opens its own rpc session for the request
+ await check_request_img(0)
+
+
@pytest.mark.parametrize(
'dgpu_workers', [(1, ['midj', 'stable'])], indirect=True)
-async def test_dgpu_workers(dgpu_workers):
+async def test_dgpu_worker_two_models(dgpu_workers):
'''Generate two images in a single dgpu worker using
two different models.
'''
- async with open_skynet_rpc(
+ with open_skynet_rpc(
'test-ctx',
- security=True,
- cert_name='whitelist/testing',
- key_name='testing'
- ) as test_rpc:
- await wait_for_dgpus(test_rpc, 1)
+ cert_name='whitelist/testing.cert',
+ key_name='testing.key'
+ ) as session:
+ await wait_for_dgpus(session, 1)
await check_request_img(0)
await check_request_img(1)
@@ -138,14 +148,12 @@ async def test_dgpu_worker_upscale(dgpu_workers):
two different models.
'''
- async with open_skynet_rpc(
+ with open_skynet_rpc(
'test-ctx',
- security=True,
- cert_name='whitelist/testing',
- key_name='testing'
- ) as test_rpc:
- await wait_for_dgpus(test_rpc, 1)
- logging.error('UPSCALE')
+ cert_name='whitelist/testing.cert',
+ key_name='testing.key'
+ ) as session:
+ await wait_for_dgpus(session, 1)
img = await check_request_img(0, upscaler='x4')
@@ -157,13 +165,12 @@ async def test_dgpu_worker_upscale(dgpu_workers):
async def test_dgpu_workers_two(dgpu_workers):
'''Generate two images in two separate dgpu workers
'''
- async with open_skynet_rpc(
+ with open_skynet_rpc(
'test-ctx',
- security=True,
- cert_name='whitelist/testing',
- key_name='testing'
- ) as test_rpc:
- await wait_for_dgpus(test_rpc, 2)
+ cert_name='whitelist/testing.cert',
+ key_name='testing.key'
+ ) as session:
+ await wait_for_dgpus(session, 2, timeout=60)
async with trio.open_nursery() as n:
n.start_soon(check_request_img, 0)
@@ -175,13 +182,12 @@ async def test_dgpu_workers_two(dgpu_workers):
async def test_dgpu_worker_algo_swap(dgpu_workers):
'''Generate an image using a non default model
'''
- async with open_skynet_rpc(
+ with open_skynet_rpc(
'test-ctx',
- security=True,
- cert_name='whitelist/testing',
- key_name='testing'
- ) as test_rpc:
- await wait_for_dgpus(test_rpc, 1)
+ cert_name='whitelist/testing.cert',
+ key_name='testing.key'
+ ) as session:
+ await wait_for_dgpus(session, 1)
await check_request_img(5)
@@ -191,33 +197,32 @@ async def test_dgpu_rotation_next_worker(dgpu_workers):
'''Connect three dgpu workers, disconnect and check next_worker
rotation happens correctly
'''
- async with open_skynet_rpc(
+ with open_skynet_rpc(
'test-ctx',
- security=True,
- cert_name='whitelist/testing',
- key_name='testing'
- ) as test_rpc:
- await wait_for_dgpus(test_rpc, 3)
+ cert_name='whitelist/testing.cert',
+ key_name='testing.key'
+ ) as session:
+ await wait_for_dgpus(session, 3)
- res = await test_rpc('dgpu_next')
+ res = await session.rpc('dgpu_next')
assert 'ok' in res.result
assert res.result['ok'] == 0
await check_request_img(0)
- res = await test_rpc('dgpu_next')
+ res = await session.rpc('dgpu_next')
assert 'ok' in res.result
assert res.result['ok'] == 1
await check_request_img(0)
- res = await test_rpc('dgpu_next')
+ res = await session.rpc('dgpu_next')
assert 'ok' in res.result
assert res.result['ok'] == 2
await check_request_img(0)
- res = await test_rpc('dgpu_next')
+ res = await session.rpc('dgpu_next')
assert 'ok' in res.result
assert res.result['ok'] == 0
@@ -228,13 +233,12 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
'''Connect three dgpu workers, disconnect the first one and check
next_worker rotation happens correctly
'''
- async with open_skynet_rpc(
+ with open_skynet_rpc(
'test-ctx',
- security=True,
- cert_name='whitelist/testing',
- key_name='testing'
- ) as test_rpc:
- await wait_for_dgpus(test_rpc, 3)
+ cert_name='whitelist/testing.cert',
+ key_name='testing.key'
+ ) as session:
+ await wait_for_dgpus(session, 3)
await trio.sleep(3)
@@ -245,7 +249,7 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers):
dgpu_workers[0].wait()
- res = await test_rpc('dgpu_workers')
+ res = await session.rpc('dgpu_workers')
assert 'ok' in res.result
assert res.result['ok'] == 2
@@ -258,26 +262,43 @@ async def test_dgpu_no_ack_node_disconnect(skynet_running):
'''Mock a node that connects, gets a request but fails to
acknowledge it, then check skynet correctly drops the node
'''
- async with open_skynet_rpc(
- 'test-ctx',
- security=True,
- cert_name='whitelist/testing',
- key_name='testing'
- ) as rpc_call:
- res = await rpc_call('dgpu_online')
- assert 'ok' in res.result
+ async def mock_rpc(req, ctx):
+ resp = SkynetRPCResponse()
+ resp.result.update({'error': 'can\'t do it mate'})
+ return resp
- await wait_for_dgpus(rpc_call, 1)
+ dgpu_addr = f'tcp://127.0.0.1:{get_random_port()}'
+ mock_server = SessionServer(
+ dgpu_addr,
+ mock_rpc,
+ cert_name='whitelist/testing.cert',
+ key_name='testing.key'
+ )
- with pytest.raises(SkynetDGPUComputeError) as e:
- await check_request_img(0)
+ async with mock_server.open():
+ with open_skynet_rpc(
+ 'test-ctx',
+ cert_name='whitelist/testing.cert',
+ key_name='testing.key'
+ ) as session:
- assert 'dgpu failed to acknowledge request' in str(e)
+ res = await session.rpc('dgpu_online', {
+ 'dgpu_addr': dgpu_addr,
+ 'cert': 'whitelist/testing.cert'
+ })
+ assert 'ok' in res.result
- res = await rpc_call('dgpu_workers')
- assert 'ok' in res.result
- assert res.result['ok'] == 0
+ await wait_for_dgpus(session, 1)
+
+ with pytest.raises(SkynetDGPUComputeError) as e:
+ await check_request_img(0)
+
+ assert 'can\'t do it mate' in str(e.value)
+
+ res = await session.rpc('dgpu_workers')
+ assert 'ok' in res.result
+ assert res.result['ok'] == 0
@pytest.mark.parametrize(
@@ -286,13 +307,12 @@ async def test_dgpu_timeout_while_processing(dgpu_workers):
'''Stop node while processing request to cause timeout and
then check skynet correctly drops the node.
'''
- async with open_skynet_rpc(
+ with open_skynet_rpc(
'test-ctx',
- security=True,
- cert_name='whitelist/testing',
- key_name='testing'
- ) as test_rpc:
- await wait_for_dgpus(test_rpc, 1)
+ cert_name='whitelist/testing.cert',
+ key_name='testing.key'
+ ) as session:
+ await wait_for_dgpus(session, 1)
async def check_request_img_raises():
with pytest.raises(SkynetDGPUComputeError) as e:
@@ -308,72 +328,62 @@ async def test_dgpu_timeout_while_processing(dgpu_workers):
assert ec == 0
-@pytest.mark.parametrize(
- 'dgpu_workers', [(1, ['midj'])], indirect=True)
-async def test_dgpu_heartbeat(dgpu_workers):
- '''
- '''
- async with open_skynet_rpc(
- 'test-ctx',
- security=True,
- cert_name='whitelist/testing',
- key_name='testing'
- ) as test_rpc:
- await wait_for_dgpus(test_rpc, 1)
- await trio.sleep(120)
-
-
@pytest.mark.parametrize(
'dgpu_workers', [(1, ['midj'])], indirect=True)
async def test_dgpu_img2img(dgpu_workers):
- async with open_skynet_rpc(
- '1',
- security=True,
- cert_name='whitelist/testing',
- key_name='testing'
- ) as rpc_call:
- await wait_for_dgpus(rpc_call, 1)
+ '''Run two ``diffuse`` calls through ``dgpu_call``: the first from a
+ prompt only, the second with the first result attached as the binary
+ extension. Both returned images are written to PNG files.
+ '''
+ with open_skynet_rpc(
+ 'test-ctx',
+ cert_name='whitelist/testing.cert',
+ key_name='testing.key'
+ ) as session:
+ await wait_for_dgpus(session, 1)
+ await trio.sleep(2)
- res = await rpc_call(
- 'txt2img', {
- 'prompt': 'red old tractor in a sunny wheat field',
- 'step': 28,
- 'width': 512, 'height': 512,
- 'guidance': 7.5,
- 'seed': None,
- 'algo': list(ALGOS.keys())[0],
- 'upscaler': None
- })
+ # first pass: prompt only, no binary extension
+ res = await session.rpc(
+ 'dgpu_call', {
+ 'method': 'diffuse',
+ 'params': {
+ 'prompt': 'red old tractor in a sunny wheat field',
+ 'step': 28,
+ 'width': 512, 'height': 512,
+ 'guidance': 7.5,
+ 'seed': None,
+ 'algo': list(ALGOS.keys())[0],
+ 'upscaler': None
+ }
+ },
+ timeout=60
+ )
if 'error' in res.result:
raise SkynetDGPUComputeError(MessageToDict(res.result))
- img_raw = res.result['img']
- img = zlib.decompress(bytes.fromhex(img_raw))
- logging.info(img[:10])
- img = Image.open(io.BytesIO(img))
-
+ img_raw = res.bin
+ img = Image.open(io.BytesIO(img_raw))
img.save('txt2img.png')
- res = await rpc_call(
- 'img2img', {
- 'prompt': 'red sports car in a sunny wheat field',
- 'step': 28,
- 'img': img_raw,
- 'guidance': 12,
- 'seed': None,
- 'algo': list(ALGOS.keys())[0],
- 'upscaler': 'x4'
- })
+ # second pass: the previous image bytes travel as binext, and the
+ # params carry 'strength' instead of width/height
+ res = await session.rpc(
+ 'dgpu_call', {
+ 'method': 'diffuse',
+ 'params': {
+ 'prompt': 'red ferrari in a sunny wheat field',
+ 'step': 28,
+ 'guidance': 8,
+ 'strength': 0.7,
+ 'seed': None,
+ 'algo': list(ALGOS.keys())[0],
+ 'upscaler': 'x4'
+ }
+ },
+ binext=img_raw,
+ timeout=60
+ )
if 'error' in res.result:
raise SkynetDGPUComputeError(MessageToDict(res.result))
- img_raw = res.result['img']
- img = zlib.decompress(bytes.fromhex(img_raw))
- logging.info(img[:10])
- img = Image.open(io.BytesIO(img))
-
+ img_raw = res.bin
+ img = Image.open(io.BytesIO(img_raw))
img.save('img2img.png')
diff --git a/tests/test_skynet.py b/tests/test_skynet.py
index 5572a70..1587d5d 100644
--- a/tests/test_skynet.py
+++ b/tests/test_skynet.py
@@ -9,6 +9,7 @@ import trio_asyncio
from skynet.brain import run_skynet
from skynet.structs import *
+from skynet.network import SessionServer
from skynet.frontend import open_skynet_rpc
@@ -18,53 +19,68 @@ async def test_skynet(skynet_running):
async def test_skynet_attempt_insecure(skynet_running):
+ '''A session opened without certificate arguments must fail with an
+ ``NNGException`` when it issues an rpc.
+ '''
with pytest.raises(pynng.exceptions.NNGException) as e:
- async with open_skynet_rpc('bad-actor'):
- ...
-
- assert str(e.value) == 'Connection shutdown'
+ with open_skynet_rpc('bad-actor') as session:
+ # bound the attempt to 5 seconds via trio.fail_after
+ with trio.fail_after(5):
+ await session.rpc('skynet_shutdown')
async def test_skynet_dgpu_connection_simple(skynet_running):
- async with open_skynet_rpc(
+ '''Register a fake dgpu (a local ``SessionServer``) through
+ ``dgpu_online`` and check that ``dgpu_workers`` and ``dgpu_next``
+ follow the online/offline calls.
+ '''
+
+ # stub handler: this test only issues worker bookkeeping rpcs, so
+ # presumably it is never invoked -- confirm
+ async def rpc_handler(req, ctx):
+ ...
+
+ # NOTE(review): fixed port here, while tests/test_dgpu.py builds its
+ # dgpu addresses with get_random_port()
+ fake_dgpu_addr = 'tcp://127.0.0.1:41001'
+ rpc_server = SessionServer(
+ fake_dgpu_addr,
+ rpc_handler,
+ cert_name='whitelist/testing.cert',
+ key_name='testing.key'
+ )
+
+ with open_skynet_rpc(
'dgpu-0',
- security=True,
- cert_name='whitelist/testing',
- key_name='testing'
- ) as rpc_call:
+ cert_name='whitelist/testing.cert',
+ key_name='testing.key'
+ ) as session:
# check 0 nodes are connected
- res = await rpc_call('dgpu_workers')
- assert 'ok' in res.result
+ res = await session.rpc('dgpu_workers')
+ assert 'ok' in res.result.keys()
assert res.result['ok'] == 0
# check next worker is None
- res = await rpc_call('dgpu_next')
- assert 'ok' in res.result
+ res = await session.rpc('dgpu_next')
+ assert 'ok' in res.result.keys()
assert res.result['ok'] == None
- # connect 1 dgpu
- res = await rpc_call('dgpu_online')
- assert 'ok' in res.result
+ async with rpc_server.open() as rpc_server:
+ # connect 1 dgpu
+ res = await session.rpc(
+ 'dgpu_online', {
+ 'dgpu_addr': fake_dgpu_addr,
+ 'cert': 'whitelist/testing.cert'
+ })
+ assert 'ok' in res.result.keys()
- # check 1 node is connected
- res = await rpc_call('dgpu_workers')
- assert 'ok' in res.result
- assert res.result['ok'] == 1
+ # check 1 node is connected
+ res = await session.rpc('dgpu_workers')
+ assert 'ok' in res.result.keys()
+ assert res.result['ok'] == 1
- # check next worker is 0
- res = await rpc_call('dgpu_next')
- assert 'ok' in res.result
- assert res.result['ok'] == 0
+ # check next worker is 0
+ res = await session.rpc('dgpu_next')
+ assert 'ok' in res.result.keys()
+ assert res.result['ok'] == 0
- # disconnect 1 dgpu
- res = await rpc_call('dgpu_offline')
- assert 'ok' in res.result
+ # disconnect 1 dgpu
+ res = await session.rpc('dgpu_offline')
+ assert 'ok' in res.result.keys()
# check 0 nodes are connected
- res = await rpc_call('dgpu_workers')
- assert 'ok' in res.result
+ res = await session.rpc('dgpu_workers')
+ assert 'ok' in res.result.keys()
assert res.result['ok'] == 0
# check next worker is None
- res = await rpc_call('dgpu_next')
- assert 'ok' in res.result
+ res = await session.rpc('dgpu_next')
+ assert 'ok' in res.result.keys()
assert res.result['ok'] == None
diff --git a/tests/test_telegram.py b/tests/test_telegram.py
new file mode 100644
index 0000000..d94a6bf
--- /dev/null
+++ b/tests/test_telegram.py
@@ -0,0 +1,28 @@
+#!/usr/bin/python
+
+import trio
+
+from functools import partial
+
+from skynet.db import open_new_database
+from skynet.brain import run_skynet
+from skynet.config import load_skynet_ini
+from skynet.frontend.telegram import run_skynet_telegram
+
+
+# Manual entry point: this file defines no test functions, so the code
+# below only runs when the module is executed directly.
+if __name__ == '__main__':
+ '''You will need a telegram bot token configured on skynet.ini for this
+ '''
+ with open_new_database() as db_params:
+ # db_container is unpacked but not used below
+ db_container, db_pass, db_host = db_params
+ config = load_skynet_ini()
+
+ # the bot token is read from the 'skynet.telegram-test' section
+ async def main():
+ await run_skynet_telegram(
+ 'telegram-test',
+ config['skynet.telegram-test']['token'],
+ db_host=db_host,
+ db_pass=db_pass
+ )
+
+ trio.run(main)