diff --git a/.gitignore b/.gitignore index e264fa9..e98d124 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +skynet.ini .python-version hf_home outputs diff --git a/Dockerfile.runtime+cuda b/Dockerfile.runtime+cuda index 32d3c4a..27a5a66 100644 --- a/Dockerfile.runtime+cuda +++ b/Dockerfile.runtime+cuda @@ -32,3 +32,4 @@ env HF_HOME /hf_home copy scripts scripts copy tests tests +expose 40000-45000 diff --git a/LICENSE b/LICENSE index 0fb76a9..fe6b903 100644 --- a/LICENSE +++ b/LICENSE @@ -1,11 +1,662 @@ -A menos que sea especificamente indicado en el cabezal del archivo, se reservan -todos los derechos sobre este codigo por parte de: + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 -Guillermo Rodriguez, guillermor@fing.edu.uy + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. -ENGLISH LICENSE: + Preamble -Unless specifically indicated in the file header, all rights to this code are -reserved by: + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. 
Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. 
+ + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. 
+ + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. 
+ + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. 
+ + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. 
This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. 
+ + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. 
+ + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. 
+ + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. 
If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. 
If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. 
+ + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. 
For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. 
+ + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. 
You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. 
+ + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. 
+ + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. 
+
+  If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+                     END OF TERMS AND CONDITIONS
+
+            How to Apply These Terms to Your New Programs
+
+  If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+  To do so, attach the following notices to the program.  It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+    <one line to give the program's name and a brief idea of what it does.>
+    Copyright (C) <year>  <name of author>
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU Affero General Public License as published
+    by the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+    GNU Affero General Public License for more details.
+
+    You should have received a copy of the GNU Affero General Public License
+    along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+  If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source.  For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code.
There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. -Guillermo Rodriguez, guillermor@.edu.uy diff --git a/requirements.test.txt b/requirements.test.txt index ce51dec..f39926b 100644 --- a/requirements.test.txt +++ b/requirements.test.txt @@ -3,4 +3,4 @@ pytest pytest-trio psycopg2-binary -git+https://github.com/guilledk/pytest-dockerctl.git@host_network#egg=pytest-dockerctl +git+https://github.com/guilledk/pytest-dockerctl.git@multi_names#egg=pytest-dockerctl diff --git a/requirements.txt b/requirements.txt index 7afc143..c773225 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,5 @@ protobuf pyOpenSSL trio_asyncio pyTelegramBotAPI + +git+https://github.com/goodboy/tractor.git@master#egg=tractor diff --git a/skynet.ini.example b/skynet.ini.example new file mode 100644 index 0000000..7035920 --- /dev/null +++ b/skynet.ini.example @@ -0,0 +1,12 @@ +[skynet] +certs_dir = certs + +[skynet.dgpu] +hf_home = hf_home +hf_token = hf_XxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXx + +[skynet.telegram] +token = XXXXXXXXXX:xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx + +[skynet.telegram-test] +token = XXXXXXXXXX:xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx diff --git a/skynet/brain.py b/skynet/brain.py index c442ba5..b121bd3 100644 --- a/skynet/brain.py +++ b/skynet/brain.py @@ -1,35 +1,24 @@ #!/usr/bin/python -import time -import json -import uuid -import zlib import logging -import traceback -from uuid import UUID -from pathlib import Path -from functools import partial from contextlib import asynccontextmanager as acm from collections import OrderedDict import trio -import pynng -import trio_asyncio -from pynng import TLSConfig -from OpenSSL.crypto import ( - 
load_privatekey, - load_certificate, - FILETYPE_PEM -) +from pynng import Context -from .db import * +from .utils import time_ms +from .network import * +from .protobuf import * from .constants import * -from .protobuf import * +class SkynetRPCBadRequest(BaseException): + ... + class SkynetDGPUOffline(BaseException): ... @@ -44,39 +33,71 @@ class SkynetShutdownRequested(BaseException): @acm -async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): +async def run_skynet( + rpc_address: str = DEFAULT_RPC_ADDR +): + logging.basicConfig(level=logging.INFO) + logging.info('skynet is starting') + nodes = OrderedDict() - wip_reqs = {} - fin_reqs = {} heartbeats = {} next_worker: Optional[int] = None - security = len(tls_whitelist) > 0 - def connect_node(uid): + def connect_node(req: SkynetRPCRequest): nonlocal next_worker - nodes[uid] = { - 'task': None - } - logging.info(f'dgpu online: {uid}') - if not next_worker: - next_worker = 0 + node_params = MessageToDict(req.params) + logging.info(f'got node params {node_params}') + + if 'dgpu_addr' not in node_params: + raise SkynetRPCBadRequest( + f'DGPU connection params don\'t include dgpu addr') + + session = SessionClient( + node_params['dgpu_addr'], + 'skynet', + cert_name='brain.cert', + key_name='brain.key', + ca_name=node_params['cert'] + ) + try: + session.connect() + + node = { + 'task': None, + 'session': session + } + node.update(node_params) + + nodes[req.uid] = node + logging.info(f'DGPU node online: {req.uid}') + + if not next_worker: + next_worker = 0 + + except pynng.exceptions.ConnectionRefused: + logging.warning(f'error while dialing dgpu node... 
dropping...') + raise SkynetDGPUOffline('Connection to dgpu node addr failed.') def disconnect_node(uid): nonlocal next_worker if uid not in nodes: + logging.warning(f'Attempt to disconnect unknown node {uid}') return + i = list(nodes.keys()).index(uid) + nodes[uid]['session'].disconnect() del nodes[uid] if i < next_worker: next_worker -= 1 + logging.warning(f'DGPU node offline: {uid}') + if len(nodes) == 0: - logging.info('nw: None') + logging.info('All nodes disconnected.') next_worker = None - logging.warning(f'dgpu offline: {uid}') def is_worker_busy(nid: str): return nodes[nid]['task'] != None @@ -90,8 +111,6 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): def get_next_worker(): nonlocal next_worker - logging.info('get next_worker called') - logging.info(f'pre next_worker: {next_worker}') if next_worker == None: raise SkynetDGPUOffline('No workers connected, try again later') @@ -113,392 +132,79 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): if next_worker >= len(nodes): next_worker = 0 - logging.info(f'post next_worker: {next_worker}') - return nid - async def dgpu_heartbeat_service(): - nonlocal heartbeats - while True: - await trio.sleep(60) - rid = uuid.uuid4().hex - beat_msg = DGPUBusMessage( - rid=rid, - nid='', - method='heartbeat' - ) - heartbeats.clear() - heartbeats[rid] = int(time.time() * 1000) - await dgpu_bus.asend(beat_msg.SerializeToString()) - logging.info('sent heartbeat') - - async def dgpu_bus_streamer(): - nonlocal wip_reqs, fin_reqs, heartbeats - while True: - raw_msg = await dgpu_bus.arecv() - logging.info(f'streamer got {len(raw_msg)} bytes.') - msg = DGPUBusMessage() - msg.ParseFromString(raw_msg) - - if security: - verify_protobuf_msg(msg, tls_whitelist[msg.auth.cert]) - - rid = msg.rid - - if msg.method == 'heartbeat': - sent_time = heartbeats[rid] - delta = msg.params['time'] - sent_time - logging.info(f'got heartbeat reply from {msg.nid}, ping: {delta}') - continue - 
- if rid not in wip_reqs: - continue - - if msg.method == 'binary-reply': - logging.info('bin reply, recv extra data') - raw_img = await dgpu_bus.arecv() - msg = (msg, raw_img) - - fin_reqs[rid] = msg - event = wip_reqs[rid] - event.set() - del wip_reqs[rid] - - async def dgpu_stream_one_img(req: DiffusionParameters, img_buf=None): - nonlocal wip_reqs, fin_reqs, next_worker - nid = get_next_worker() - idx = list(nodes.keys()).index(nid) - logging.info(f'dgpu_stream_one_img {idx}/{len(nodes)} {nid}') - rid = uuid.uuid4().hex - ack_event = trio.Event() - img_event = trio.Event() - wip_reqs[rid] = ack_event - - nodes[nid]['task'] = rid - - dgpu_req = DGPUBusMessage( - rid=rid, - nid=nid, - method='diffuse') - dgpu_req.params.update(req.to_dict()) - - if security: - dgpu_req.auth.cert = 'skynet' - dgpu_req.auth.sig = sign_protobuf_msg(dgpu_req, tls_key) - - msg = dgpu_req.SerializeToString() - if img_buf: - logging.info(f'sending img of size {len(img_buf)} as attachment') - logging.info(img_buf[:10]) - msg = f'BINEXT%$%$'.encode() + msg + b'%$%$' + img_buf - - await dgpu_bus.asend(msg) - - with trio.move_on_after(4): - await ack_event.wait() - - logging.info(f'ack event: {ack_event.is_set()}') - - if not ack_event.is_set(): - disconnect_node(nid) - raise SkynetDGPUOffline('dgpu failed to acknowledge request') - - ack_msg = fin_reqs[rid] - if 'ack' not in ack_msg.params: - disconnect_node(nid) - raise SkynetDGPUOffline('dgpu failed to acknowledge request') - - wip_reqs[rid] = img_event - with trio.move_on_after(30): - await img_event.wait() - - logging.info(f'img event: {ack_event.is_set()}') - - if not img_event.is_set(): - disconnect_node(nid) - raise SkynetDGPUComputeError('30 seconds timeout while processing request') - - nodes[nid]['task'] = None - - resp = fin_reqs[rid] - del fin_reqs[rid] - if isinstance(resp, tuple): - meta, img = resp - return rid, img, meta.params - - raise SkynetDGPUComputeError(MessageToDict(resp.params)) - - - async def 
handle_user_request(rpc_ctx, req): - try: - async with db_pool.acquire() as conn: - user = await get_or_create_user(conn, req.uid) - - result = {} - - match req.method: - case 'txt2img': - logging.info('txt2img') - user_config = {**(await get_user_config(conn, user))} - del user_config['id'] - user_config.update(MessageToDict(req.params)) - - req = DiffusionParameters(**user_config, image=False) - rid, img, meta = await dgpu_stream_one_img(req) - logging.info(f'done streaming {rid}') - result = { - 'id': rid, - 'img': img.hex(), - 'meta': meta - } - - await update_user_stats(conn, user, last_prompt=user_config['prompt']) - logging.info('updated user stats.') - - case 'img2img': - logging.info('img2img') - user_config = {**(await get_user_config(conn, user))} - del user_config['id'] - - params = MessageToDict(req.params) - img_buf = bytes.fromhex(params['img']) - del params['img'] - user_config.update(params) - - req = DiffusionParameters(**user_config, image=True) - - if not req.image: - raise AssertionError('Didn\'t enable image flag for img2img?') - - rid, img, meta = await dgpu_stream_one_img(req, img_buf=img_buf) - logging.info(f'done streaming {rid}') - result = { - 'id': rid, - 'img': img.hex(), - 'meta': meta - } - - await update_user_stats(conn, user, last_prompt=user_config['prompt']) - logging.info('updated user stats.') - - case 'redo': - logging.info('redo') - user_config = {**(await get_user_config(conn, user))} - del user_config['id'] - prompt = await get_last_prompt_of(conn, user) - - if prompt: - req = DiffusionParameters( - prompt=prompt, - **user_config, - image=False - ) - rid, img, meta = await dgpu_stream_one_img(req) - result = { - 'id': rid, - 'img': img.hex(), - 'meta': meta - } - await update_user_stats(conn, user) - logging.info('updated user stats.') - - else: - result = { - 'error': 'skynet_no_last_prompt', - 'message': 'No prompt to redo, do txt2img first' - } - - case 'config': - logging.info('config') - if req.params['attr'] in 
CONFIG_ATTRS: - logging.info(f'update: {req.params}') - await update_user_config( - conn, user, req.params['attr'], req.params['val']) - logging.info('done') - - else: - logging.warning(f'{req.params["attr"]} not in {CONFIG_ATTRS}') - - case 'stats': - logging.info('stats') - generated, joined, role = await get_user_stats(conn, user) - - result = { - 'generated': generated, - 'joined': joined.strftime(DATE_FORMAT), - 'role': role - } - - case _: - logging.warn('unknown method') - - except SkynetDGPUOffline as e: - result = { - 'error': 'skynet_dgpu_offline', - 'message': str(e) - } - - except SkynetDGPUOverloaded as e: - result = { - 'error': 'skynet_dgpu_overloaded', - 'message': str(e), - 'nodes': len(nodes) - } - - except SkynetDGPUComputeError as e: - result = { - 'error': 'skynet_dgpu_compute_error', - 'message': str(e) - } - except BaseException as e: - traceback.print_exception(type(e), e, e.__traceback__) - result = { - 'error': 'skynet_internal_error', - 'message': str(e) - } - + async def rpc_handler(req: SkynetRPCRequest, ctx: Context): + result = {'ok': {}} resp = SkynetRPCResponse() - resp.result.update(result) - - if security: - resp.auth.cert = 'skynet' - resp.auth.sig = sign_protobuf_msg(resp, tls_key) - - logging.info('sending response') - await rpc_ctx.asend(resp.SerializeToString()) - rpc_ctx.close() - logging.info('done') - - async def request_service(n): - nonlocal next_worker - while True: - ctx = sock.new_context() - req = SkynetRPCRequest() - req.ParseFromString(await ctx.arecv()) - - if security: - if req.auth.cert not in tls_whitelist: - logging.warning( - f'{req.cert} not in tls whitelist and security=True') - continue - - try: - verify_protobuf_msg(req, tls_whitelist[req.auth.cert]) - - except ValueError: - logging.warning( - f'{req.cert} sent an unauthenticated msg with security=True') - continue - - result = {} + try: match req.method: - case 'skynet_shutdown': - raise SkynetShutdownRequested - case 'dgpu_online': - 
connect_node(req.uid) + connect_node(req) + + case 'dgpu_call': + nid = get_next_worker() + idx = list(nodes.keys()).index(nid) + node = nodes[nid] + logging.info(f'dgpu_call {idx}/{len(nodes)} {nid} @ {node["dgpu_addr"]}') + dgpu_time = await node['session'].rpc('dgpu_time') + if 'ok' not in dgpu_time.result: + status = MessageToDict(dgpu_time.result) + logging.warning(json.dumps(status, indent=4)) + disconnect_node(nid) + raise SkynetDGPUComputeError(status['error']) + + dgpu_time = dgpu_time.result['ok'] + logging.info(f'ping to {nid}: {time_ms() - dgpu_time} ms') + + try: + dgpu_result = await node['session'].rpc( + timeout=45, # give this 45 sec to run cause its compute + binext=req.bin, + **req.params + ) + result = MessageToDict(dgpu_result.result) + + if dgpu_result.bin: + resp.bin = dgpu_result.bin + + except trio.TooSlowError: + result = {'error': 'timeout while processing request'} case 'dgpu_offline': disconnect_node(req.uid) case 'dgpu_workers': - result = len(nodes) + result = {'ok': len(nodes)} case 'dgpu_next': - result = next_worker + result = {'ok': next_worker} - case 'heartbeat': - logging.info('beat') - result = {'time': time.time()} + case 'skynet_shutdown': + raise SkynetShutdownRequested case _: - n.start_soon( - handle_user_request, ctx, req) - continue + logging.warning(f'Unknown method {req.method}') + result = {'error': 'unknown method'} - resp = SkynetRPCResponse() - resp.result.update({'ok': result}) + except BaseException as e: + result = {'error': str(e)} - if security: - resp.auth.cert = 'skynet' - resp.auth.sig = sign_protobuf_msg(resp, tls_key) + resp.result.update(result) - await ctx.asend(resp.SerializeToString()) + return resp - ctx.close() + rpc_server = SessionServer( + rpc_address, + rpc_handler, + cert_name='brain.cert', + key_name='brain.key' + ) - - async with trio.open_nursery() as n: - n.start_soon(dgpu_bus_streamer) - n.start_soon(dgpu_heartbeat_service) - n.start_soon(request_service, n) - logging.info('starting rpc 
service') + async with rpc_server.open(): + logging.info('rpc server is up') yield - logging.info('stopping rpc service') - n.cancel_scope.cancel() + logging.info('skynet is shuting down...') - -@acm -async def run_skynet( - db_user: str = DB_USER, - db_pass: str = DB_PASS, - db_host: str = DB_HOST, - rpc_address: str = DEFAULT_RPC_ADDR, - dgpu_address: str = DEFAULT_DGPU_ADDR, - security: bool = True -): - logging.basicConfig(level=logging.INFO) - logging.info('skynet is starting') - - tls_config = None - if security: - # load tls certs - certs_dir = Path(DEFAULT_CERTS_DIR).resolve() - - tls_key_data = (certs_dir / DEFAULT_CERT_SKYNET_PRIV).read_text() - tls_key = load_privatekey(FILETYPE_PEM, tls_key_data) - - tls_cert_data = (certs_dir / DEFAULT_CERT_SKYNET_PUB).read_text() - tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data) - - tls_whitelist = {} - for cert_path in (certs_dir / 'whitelist').glob('*.cert'): - tls_whitelist[cert_path.stem] = load_certificate( - FILETYPE_PEM, cert_path.read_text()) - - cert_start = tls_cert_data.index('\n') + 1 - logging.info(f'tls_cert: {tls_cert_data[cert_start:cert_start+64]}...') - logging.info(f'tls_whitelist len: {len(tls_whitelist)}') - - rpc_address = 'tls+' + rpc_address - dgpu_address = 'tls+' + dgpu_address - tls_config = TLSConfig( - TLSConfig.MODE_SERVER, - own_key_string=tls_key_data, - own_cert_string=tls_cert_data) - - with ( - pynng.Rep0(recv_max_size=0) as rpc_sock, - pynng.Bus0(recv_max_size=0) as dgpu_bus - ): - async with open_database_connection( - db_user, db_pass, db_host) as db_pool: - - logging.info('connected to db.') - if security: - rpc_sock.tls_config = tls_config - dgpu_bus.tls_config = tls_config - - rpc_sock.listen(rpc_address) - dgpu_bus.listen(dgpu_address) - - try: - async with open_rpc_service( - rpc_sock, dgpu_bus, db_pool, tls_whitelist, tls_key): - yield - - except SkynetShutdownRequested: - ... 
- - logging.info('disconnected from db.') + logging.info('skynet down.') diff --git a/skynet/cli.py b/skynet/cli.py index 2573106..021e1e6 100644 --- a/skynet/cli.py +++ b/skynet/cli.py @@ -17,8 +17,8 @@ if torch_enabled: from .dgpu import open_dgpu_node from .brain import run_skynet +from .config import * from .constants import ALGOS, DEFAULT_RPC_ADDR, DEFAULT_DGPU_ADDR - from .frontend.telegram import run_skynet_telegram @@ -38,8 +38,8 @@ def skynet(*args, **kwargs): @click.option('--steps', '-s', default=26) @click.option('--seed', '-S', default=None) def txt2img(*args, **kwargs): - assert 'HF_TOKEN' in os.environ - utils.txt2img(os.environ['HF_TOKEN'], **kwargs) + _, hf_token, _, cfg = init_env_from_config() + utils.txt2img(hf_token, **kwargs) @click.command() @click.option('--model', '-m', default='midj') @@ -52,9 +52,9 @@ def txt2img(*args, **kwargs): @click.option('--steps', '-s', default=26) @click.option('--seed', '-S', default=None) def img2img(model, prompt, input, output, strength, guidance, steps, seed): - assert 'HF_TOKEN' in os.environ + _, hf_token, _, cfg = init_env_from_config() utils.img2img( - os.environ['HF_TOKEN'], + hf_token, model=model, prompt=prompt, img_path=input, @@ -76,6 +76,12 @@ def upscale(input, output, model): model_path=model) +@skynet.command() +def download(): + _, hf_token, _, cfg = init_env_from_config() + utils.download_all_models(hf_token) + + @skynet.group() def run(*args, **kwargs): pass @@ -85,29 +91,17 @@ def run(*args, **kwargs): @click.option('--loglevel', '-l', default='warning', help='Logging level') @click.option( '--host', '-H', default=DEFAULT_RPC_ADDR) -@click.option( - '--host-dgpu', '-D', default=DEFAULT_DGPU_ADDR) -@click.option( - '--db-host', '-h', default='localhost:5432') -@click.option( - '--db-pass', '-p', default='password') def brain( loglevel: str, - host: str, - host_dgpu: str, - db_host: str, - db_pass: str + host: str ): async def _run_skynet(): async with run_skynet( - db_host=db_host, - 
db_pass=db_pass, - rpc_address=host, - dgpu_address=host_dgpu + rpc_address=host ): await trio.sleep_forever() - trio_asyncio.run(_run_skynet) + trio.run(_run_skynet) @run.command() @@ -115,9 +109,9 @@ def brain( @click.option( '--uid', '-u', required=True) @click.option( - '--key', '-k', default='dgpu') + '--key', '-k', default='dgpu.key') @click.option( - '--cert', '-c', default='whitelist/dgpu') + '--cert', '-c', default='whitelist/dgpu.cert') @click.option( '--algos', '-a', default=json.dumps(['midj'])) @click.option( @@ -159,11 +153,11 @@ def telegram( cert: str, rpc: str ): - assert 'TG_TOKEN' in os.environ + _, _, tg_token, cfg = init_env_from_config() trio_asyncio.run( partial( run_skynet_telegram, - os.environ['TG_TOKEN'], + tg_token, key_name=key, cert_name=cert, rpc_address=rpc diff --git a/skynet/config.py b/skynet/config.py new file mode 100644 index 0000000..91d6101 --- /dev/null +++ b/skynet/config.py @@ -0,0 +1,39 @@ +#!/usr/bin/python + +import os + +from pathlib import Path +from configparser import ConfigParser + +from .constants import DEFAULT_CONFIG_PATH + + +def load_skynet_ini( + file_path=DEFAULT_CONFIG_PATH +): + config = ConfigParser() + config.read(file_path) + return config + + +def init_env_from_config( + file_path=DEFAULT_CONFIG_PATH +): + config = load_skynet_ini() + + if 'HF_TOKEN' in os.environ: + hf_token = os.environ['HF_TOKEN'] + else: + hf_token = config['skynet.dgpu']['hf_token'] + + if 'HF_HOME' in os.environ: + hf_home = os.environ['HF_HOME'] + else: + hf_home = config['skynet.dgpu']['hf_home'] + + if 'TG_TOKEN' in os.environ: + tg_token = os.environ['TG_TOKEN'] + else: + tg_token = config['skynet.telegram']['token'] + + return hf_home, hf_token, tg_token, config diff --git a/skynet/constants.py b/skynet/constants.py index 1478269..3d96a2c 100644 --- a/skynet/constants.py +++ b/skynet/constants.py @@ -1,14 +1,9 @@ #!/usr/bin/python -VERSION = '0.1a8' +VERSION = '0.1a9' DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda' -DB_HOST = 
'localhost:5432' -DB_USER = 'skynet' -DB_PASS = 'password' -DB_NAME = 'skynet' - ALGOS = { 'midj': 'prompthero/openjourney', 'stable': 'runwayml/stable-diffusion-v1-5', @@ -118,6 +113,7 @@ DEFAULT_ALGO = 'midj' DEFAULT_ROLE = 'pleb' DEFAULT_UPSCALER = None +DEFAULT_CONFIG_PATH = 'skynet.ini' DEFAULT_CERTS_DIR = 'certs' DEFAULT_CERT_WHITELIST_DIR = 'whitelist' DEFAULT_CERT_SKYNET_PUB = 'brain.cert' diff --git a/skynet/db/__init__.py b/skynet/db/__init__.py new file mode 100644 index 0000000..fd45c9e --- /dev/null +++ b/skynet/db/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/python + +from .proxy import open_database_connection + +from .functions import open_new_database diff --git a/skynet/db.py b/skynet/db/functions.py similarity index 73% rename from skynet/db.py rename to skynet/db/functions.py index fbcf202..10863c2 100644 --- a/skynet/db.py +++ b/skynet/db/functions.py @@ -1,18 +1,21 @@ #!/usr/bin/python +import time +import random +import string import logging from typing import Optional from datetime import datetime -from contextlib import asynccontextmanager as acm +from contextlib import contextmanager as cm -import trio -import triopg -import trio_asyncio +import docker +import psycopg2 from asyncpg.exceptions import UndefinedColumnError +from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT -from .constants import * +from ..constants import * DB_INIT_SQL = ''' @@ -75,29 +78,67 @@ def try_decode_uid(uid: str): return None, None -@acm -async def open_database_connection( - db_user: str = DB_USER, - db_pass: str = DB_PASS, - db_host: str = DB_HOST, - db_name: str = DB_NAME -): - async with trio_asyncio.open_loop() as loop: - async with triopg.create_pool( - dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}' - ) as pool_conn: - async with pool_conn.acquire() as conn: - res = await conn.execute(f''' - select distinct table_schema - from information_schema.tables - where table_schema = \'{db_name}\' - ''') - if '1' in res: - logging.info('schema already 
in db, skipping init') - else: - await conn.execute(DB_INIT_SQL) +@cm +def open_new_database(): + rpassword = ''.join( + random.choice(string.ascii_lowercase) + for i in range(12)) + password = ''.join( + random.choice(string.ascii_lowercase) + for i in range(12)) - yield pool_conn + dclient = docker.from_env() + + container = dclient.containers.run( + 'postgres', + name='skynet-test-postgres', + ports={'5432/tcp': None}, + environment={ + 'POSTGRES_PASSWORD': rpassword + }, + detach=True, + remove=True + ) + + for log in container.logs(stream=True): + log = log.decode().rstrip() + logging.info(log) + if ('database system is ready to accept connections' in log or + 'database system is shut down' in log): + break + + # ip = container.attrs['NetworkSettings']['IPAddress'] + container.reload() + port = container.ports['5432/tcp'][0]['HostPort'] + host = f'localhost:{port}' + + # why print the system is ready to accept connections when its not + # postgres? wtf + time.sleep(1) + logging.info('creating skynet db...') + + conn = psycopg2.connect( + user='postgres', + password=rpassword, + host='localhost', + port=port + ) + logging.info('connected...') + conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) + with conn.cursor() as cursor: + cursor.execute( + f'CREATE USER skynet WITH PASSWORD \'{password}\'') + cursor.execute( + f'CREATE DATABASE skynet') + cursor.execute( + f'GRANT ALL PRIVILEGES ON DATABASE skynet TO skynet') + + conn.close() + + logging.info('done.') + yield container, password, host + + container.stop() async def get_user(conn, uid: str): diff --git a/skynet/db/proxy.py b/skynet/db/proxy.py new file mode 100644 index 0000000..d2f86c1 --- /dev/null +++ b/skynet/db/proxy.py @@ -0,0 +1,123 @@ +#!/usr/bin/python + +import importlib + +from contextlib import asynccontextmanager as acm + +import trio +import tractor +import asyncpg +import asyncio +import trio_asyncio + + +_spawn_kwargs = { + 'infect_asyncio': True, +} + + +async def aio_db_proxy( + 
to_trio: trio.MemorySendChannel, + from_trio: asyncio.Queue, + db_user: str = 'skynet', + db_pass: str = 'password', + db_host: str = 'localhost:5432', + db_name: str = 'skynet' +) -> None: + db = importlib.import_module('skynet.db.functions') + + pool = await asyncpg.create_pool( + dsn=f'postgres://{db_user}:{db_pass}@{db_host}/{db_name}') + + async with pool_conn.acquire() as conn: + res = await conn.execute(f''' + select distinct table_schema + from information_schema.tables + where table_schema = \'{db_name}\' + ''') + if '1' in res: + logging.info('schema already in db, skipping init') + else: + await conn.execute(DB_INIT_SQL) + + # a first message must be sent **from** this ``asyncio`` + # task or the ``trio`` side will never unblock from + # ``tractor.to_asyncio.open_channel_from():`` + to_trio.send_nowait('start') + + # XXX: this uses an ``from_trio: asyncio.Queue`` currently but we + # should probably offer something better. + while True: + msg = await from_trio.get() + + method = getattr(db, msg.get('method')) + args = getattr(db, msg.get('args', [])) + kwargs = getattr(db, msg.get('kwargs', {})) + + async with pool_conn.acquire() as conn: + result = await method(conn, *args, **kwargs) + to_trio.send_nowait(result) + + +@tractor.context +async def trio_to_aio_db_proxy( + ctx: tractor.Context, + db_user: str = 'skynet', + db_pass: str = 'password', + db_host: str = 'localhost:5432', + db_name: str = 'skynet' +): + # this will block until the ``asyncio`` task sends a "first" + # message. 
+ async with tractor.to_asyncio.open_channel_from( + aio_db_proxy, + db_user=db_user, + db_pass=db_pass, + db_host=db_host, + db_name=db_name + ) as (first, chan): + + assert first == 'start' + await ctx.started(first) + + async with ctx.open_stream() as stream: + + async for msg in stream: + await chan.send(msg) + + out = await chan.receive() + # echo back to parent actor-task + await stream.send(out) + + +@acm +async def open_database_connection( + db_user: str = 'skynet', + db_pass: str = 'password', + db_host: str = 'localhost:5432', + db_name: str = 'skynet' +): + async with tractor.open_nursery() as n: + p = await n.start_actor( + 'aio_db_proxy', + enable_modules=[__name__], + infect_asyncio=True, + ) + async with p.open_context( + trio_to_aio_db_proxy, + db_user=db_user, + db_pass=db_pass, + db_host=db_host, + db_name=db_name + ) as (ctx, first): + async with ctx.open_stream() as stream: + + async def _db_pc(method: str, *args, **kwargs): + await stream.send({ + 'method': method, + 'args': args, + 'kwargs': kwargs + }) + return await stream.receive() + + yield _db_pc diff --git a/skynet/dgpu.py b/skynet/dgpu.py index 752c8b8..79c6c49 100644 --- a/skynet/dgpu.py +++ b/skynet/dgpu.py @@ -2,29 +2,17 @@ import gc import io -import trio import json -import uuid -import time -import zlib import random import logging -import traceback from PIL import Image from typing import List, Optional -from pathlib import Path -from contextlib import ExitStack -import pynng +import trio import torch -from pynng import TLSConfig -from OpenSSL.crypto import ( - load_privatekey, - load_certificate, - FILETYPE_PEM -) +from pynng import Context from diffusers import ( StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, @@ -34,12 +22,9 @@ from realesrgan import RealESRGANer from basicsr.archs.rrdbnet_arch import RRDBNet from diffusers.models import UNet2DConditionModel -from .utils import ( - pipeline_for, - convert_from_cv2_to_image, convert_from_image_to_cv2 -) +from .utils 
import * +from .network import * from .protobuf import * -from .frontend import open_skynet_rpc from .constants import * @@ -64,65 +49,16 @@ class DGPUComputeError(BaseException): ... -class ReconnectingBus: - - def __init__(self, address: str, tls_config: Optional[TLSConfig]): - self.address = address - self.tls_config = tls_config - - self._stack = ExitStack() - self._sock = None - self._closed = True - - def connect(self): - self._sock = self._stack.enter_context( - pynng.Bus0(recv_max_size=0)) - self._sock.tls_config = self.tls_config - self._sock.dial(self.address) - self._closed = False - - async def arecv(self): - while True: - try: - return await self._sock.arecv() - - except pynng.exceptions.Closed: - if self._closed: - raise - - async def asend(self, msg): - while True: - try: - return await self._sock.asend(msg) - - except pynng.exceptions.Closed: - if self._closed: - raise - - def close(self): - self._stack.close() - self._stack = ExitStack() - self._closed = True - - def reconnect(self): - self.close() - self.connect() - - async def open_dgpu_node( cert_name: str, unique_id: str, key_name: Optional[str], rpc_address: str = DEFAULT_RPC_ADDR, dgpu_address: str = DEFAULT_DGPU_ADDR, - initial_algos: Optional[List[str]] = None, - security: bool = True + initial_algos: Optional[List[str]] = None ): - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=logging.DEBUG) logging.info(f'starting dgpu node!') - - name = uuid.uuid4() - logging.info(f'loading models...') upscaler = init_upscaler() @@ -141,241 +77,140 @@ async def open_dgpu_node( logging.info('memory summary:') logging.info('\n' + torch.cuda.memory_summary()) - async def gpu_compute_one(ireq: DiffusionParameters, image=None): - algo = ireq.algo + 'img' if image else ireq.algo - if algo not in models: - least_used = list(models.keys())[0] - for model in models: - if models[least_used]['generated'] > models[model]['generated']: - least_used = model + async def gpu_compute_one(method: 
str, params: dict, binext: Optional[bytes] = None): + match method: + case 'diffuse': + image = None + algo = params['algo'] + if binext: + algo += 'img' + image = Image.open(io.BytesIO(binext)) + w, h = image.size + logging.info(f'user sent img of size {image.size}') - del models[least_used] - gc.collect() + if w > 512 or h > 512: + image.thumbnail((512, 512)) + logging.info(f'resized it to {image.size}') - models[algo] = { - 'pipe': pipeline_for(ireq.algo, image=True if image else False), - 'generated': 0 - } + if algo not in models: + logging.info(f'{algo} not in loaded models, swapping...') + least_used = list(models.keys())[0] + for model in models: + if models[least_used]['generated'] > models[model]['generated']: + least_used = model - _params = {} - if ireq.image: - _params['image'] = image - _params['strength'] = ireq.strength + del models[least_used] + gc.collect() - else: - _params['width'] = int(ireq.width) - _params['height'] = int(ireq.height) + models[algo] = { + 'pipe': pipeline_for(params['algo'], image=True if binext else False), + 'generated': 0 + } + logging.info(f'swapping done.') - try: - image = models[algo]['pipe']( - ireq.prompt, - **_params, - guidance_scale=ireq.guidance, - num_inference_steps=int(ireq.step), - generator=torch.Generator("cuda").manual_seed(ireq.seed) - ).images[0] + _params = {} + logging.info(method) + logging.info(json.dumps(params, indent=4)) + logging.info(f'binext: {len(binext) if binext else 0} bytes') + if binext: + _params['image'] = image + _params['strength'] = params['strength'] - if ireq.upscaler == 'x4': - logging.info(f'size: {len(image.tobytes())}') - logging.info('performing upscale...') - input_img = image.convert('RGB') - up_img, _ = upscaler.enhance( - convert_from_image_to_cv2(input_img), outscale=4) + else: + _params['width'] = int(params['width']) + _params['height'] = int(params['height']) - image = convert_from_cv2_to_image(up_img) - logging.info('done') + try: + image = models[algo]['pipe']( + 
params['prompt'], + **_params, + guidance_scale=params['guidance'], + num_inference_steps=int(params['step']), + generator=torch.Generator("cuda").manual_seed( + int(params['seed']) if params['seed'] else random.randint(0, 2 ** 64) + ) + ).images[0] - img_byte_arr = io.BytesIO() - image.save(img_byte_arr, format='PNG') - raw_img = img_byte_arr.getvalue() - logging.info(f'final img size {len(raw_img)} bytes.') + if params['upscaler'] == 'x4': + logging.info(f'size: {len(image.tobytes())}') + logging.info('performing upscale...') + input_img = image.convert('RGB') + up_img, _ = upscaler.enhance( + convert_from_image_to_cv2(input_img), outscale=4) - return raw_img + image = convert_from_cv2_to_image(up_img) + logging.info('done') - except BaseException as e: - logging.error(e) - raise DGPUComputeError(str(e)) + img_byte_arr = io.BytesIO() + image.save(img_byte_arr, format='PNG') + raw_img = img_byte_arr.getvalue() + logging.info(f'final img size {len(raw_img)} bytes.') - finally: - torch.cuda.empty_cache() + return raw_img + + except BaseException as e: + logging.error(e) + raise DGPUComputeError(str(e)) + + finally: + torch.cuda.empty_cache() + + case _: + raise DGPUComputeError('Unsupported compute method') + + async def rpc_handler(req: SkynetRPCRequest, ctx: Context): + result = {} + resp = SkynetRPCResponse() + + match req.method: + case 'dgpu_time': + result = {'ok': time_ms()} + + case _: + logging.debug(f'dgpu got one request: {req.method}') + try: + resp.bin = await gpu_compute_one( + req.method, MessageToDict(req.params), + binext=req.bin if req.bin else None + ) + logging.debug(f'dgpu processed one request') + + except DGPUComputeError as e: + result = {'error': str(e)} + + resp.result.update(result) + return resp + + rpc_server = SessionServer( + dgpu_address, + rpc_handler, + cert_name=cert_name, + key_name=key_name + ) + skynet_rpc = SessionClient( + rpc_address, + unique_id, + cert_name=cert_name, + key_name=key_name + ) + skynet_rpc.connect() - async 
with ( - open_skynet_rpc( - unique_id, - rpc_address=rpc_address, - security=security, - cert_name=cert_name, - key_name=key_name - ) as rpc_call, - trio.open_nursery() as n - ): + async with rpc_server.open() as rpc_server: + res = await skynet_rpc.rpc( + 'dgpu_online', { + 'dgpu_addr': rpc_server.addr, + 'cert': cert_name + }) - tls_config = None - if security: - # load tls certs - if not key_name: - key_name = cert_name - - certs_dir = Path(DEFAULT_CERTS_DIR).resolve() - - skynet_cert_path = certs_dir / 'brain.cert' - tls_cert_path = certs_dir / f'{cert_name}.cert' - tls_key_path = certs_dir / f'{key_name}.key' - - cert_name = tls_cert_path.stem - - skynet_cert_data = skynet_cert_path.read_text() - skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data) - - tls_cert_data = tls_cert_path.read_text() - - tls_key_data = tls_key_path.read_text() - tls_key = load_privatekey(FILETYPE_PEM, tls_key_data) - - logging.info(f'skynet cert: {skynet_cert_path}') - logging.info(f'dgpu cert: {tls_cert_path}') - logging.info(f'dgpu key: {tls_key_path}') - - dgpu_address = 'tls+' + dgpu_address - tls_config = TLSConfig( - TLSConfig.MODE_CLIENT, - own_key_string=tls_key_data, - own_cert_string=tls_cert_data, - ca_string=skynet_cert_data) - - logging.info(f'connecting to {dgpu_address}') - - dgpu_bus = ReconnectingBus(dgpu_address, tls_config) - dgpu_bus.connect() - - last_msg = time.time() - async def connection_refresher(refresh_time: int = 120): - nonlocal last_msg - while True: - now = time.time() - last_msg_time_delta = now - last_msg - logging.info(f'time since last msg: {last_msg_time_delta}') - if last_msg_time_delta > refresh_time: - dgpu_bus.reconnect() - logging.info('reconnected!') - last_msg = now - - await trio.sleep(refresh_time) - - n.start_soon(connection_refresher) - - res = await rpc_call('dgpu_online') assert 'ok' in res.result try: - while True: - msg = await dgpu_bus.arecv() - - img = None - if b'BINEXT' in msg: - header, msg, img_raw = 
msg.split(b'%$%$') - logging.info(f'got img attachment of size {len(img_raw)}') - logging.info(img_raw[:10]) - raw_img = zlib.decompress(img_raw) - logging.info(raw_img[:10]) - img = Image.open(io.BytesIO(raw_img)) - w, h = img.size - logging.info(f'user sent img of size {img.size}') - - if w > 512 or h > 512: - img.thumbnail((512, 512)) - logging.info(f'resized it to {img.size}') - - - req = DGPUBusMessage() - req.ParseFromString(msg) - last_msg = time.time() - - if req.method == 'heartbeat': - rep = DGPUBusMessage( - rid=req.rid, - nid=unique_id, - method=req.method - ) - rep.params.update({'time': int(time.time() * 1000)}) - - if security: - rep.auth.cert = cert_name - rep.auth.sig = sign_protobuf_msg(rep, tls_key) - - await dgpu_bus.asend(rep.SerializeToString()) - logging.info('heartbeat reply') - continue - - if req.nid != unique_id: - logging.info( - f'witnessed msg {req.rid}, node involved: {req.nid}') - continue - - if security: - verify_protobuf_msg(req, skynet_cert) - - - ack_resp = DGPUBusMessage( - rid=req.rid, - nid=req.nid - ) - ack_resp.params.update({'ack': {}}) - - if security: - ack_resp.auth.cert = cert_name - ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key) - - # send ack - await dgpu_bus.asend(ack_resp.SerializeToString()) - - logging.info(f'sent ack, processing {req.rid}...') - - try: - img_req = DiffusionParameters(**req.params) - - if not img_req.seed: - img_req.seed = random.randint(0, 2 ** 64) - - img = await gpu_compute_one(img_req, image=img) - img_resp = DGPUBusMessage( - rid=req.rid, - nid=req.nid, - method='binary-reply' - ) - img_resp.params.update({ - 'len': len(img), - 'meta': img_req.to_dict() - }) - - except DGPUComputeError as e: - traceback.print_exception(type(e), e, e.__traceback__) - img_resp = DGPUBusMessage( - rid=req.rid, - nid=req.nid - ) - img_resp.params.update({'error': str(e)}) - - - if security: - img_resp.auth.cert = cert_name - img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key) - - # send final 
image - logging.info('sending img back...') - raw_msg = img_resp.SerializeToString() - await dgpu_bus.asend(raw_msg) - logging.info(f'sent {len(raw_msg)} bytes.') - if img_resp.method == 'binary-reply': - await dgpu_bus.asend(zlib.compress(img)) - logging.info(f'sent {len(img)} bytes.') + await trio.sleep_forever() except KeyboardInterrupt: logging.info('interrupt caught, stopping...') - n.cancel_scope.cancel() - dgpu_bus.close() finally: - res = await rpc_call('dgpu_offline') + res = await skynet_rpc.rpc('dgpu_offline') assert 'ok' in res.result diff --git a/skynet/frontend/__init__.py b/skynet/frontend/__init__.py index f8193a2..04d6b90 100644 --- a/skynet/frontend/__init__.py +++ b/skynet/frontend/__init__.py @@ -4,7 +4,7 @@ import json from typing import Union, Optional from pathlib import Path -from contextlib import asynccontextmanager as acm +from contextlib import contextmanager as cm import pynng @@ -17,6 +17,7 @@ from OpenSSL.crypto import ( from google.protobuf.struct_pb2 import Struct +from ..network import SessionClient from ..constants import * from ..protobuf.auth import * @@ -39,75 +40,23 @@ class ConfigSizeDivisionByEight(BaseException): ... 
-@acm -async def open_skynet_rpc( +@cm +def open_skynet_rpc( unique_id: str, rpc_address: str = DEFAULT_RPC_ADDR, - security: bool = False, cert_name: Optional[str] = None, key_name: Optional[str] = None ): - tls_config = None - - if security: - # load tls certs - if not key_name: - key_name = cert_name - - certs_dir = Path(DEFAULT_CERTS_DIR).resolve() - - skynet_cert_data = (certs_dir / 'brain.cert').read_text() - skynet_cert = load_certificate(FILETYPE_PEM, skynet_cert_data) - - tls_cert_path = certs_dir / f'{cert_name}.cert' - tls_cert_data = tls_cert_path.read_text() - tls_cert = load_certificate(FILETYPE_PEM, tls_cert_data) - cert_name = tls_cert_path.stem - - tls_key_data = (certs_dir / f'{key_name}.key').read_text() - tls_key = load_privatekey(FILETYPE_PEM, tls_key_data) - - rpc_address = 'tls+' + rpc_address - tls_config = TLSConfig( - TLSConfig.MODE_CLIENT, - own_key_string=tls_key_data, - own_cert_string=tls_cert_data, - ca_string=skynet_cert_data) - - with pynng.Req0(recv_max_size=0) as sock: - if security: - sock.tls_config = tls_config - - sock.dial(rpc_address) - - async def _rpc_call( - method: str, - params: dict = {}, - uid: Optional[str] = None - ): - req = SkynetRPCRequest() - req.uid = uid if uid else unique_id - req.method = method - req.params.update(params) - - if security: - req.auth.cert = cert_name - req.auth.sig = sign_protobuf_msg(req, tls_key) - - ctx = sock.new_context() - await ctx.asend(req.SerializeToString()) - - resp = SkynetRPCResponse() - resp.ParseFromString(await ctx.arecv()) - ctx.close() - - if security: - verify_protobuf_msg(resp, skynet_cert) - - return resp - - yield _rpc_call - + sesh = SessionClient( + rpc_address, + unique_id, + cert_name=cert_name, + key_name=key_name + ) + logging.debug(f'opening skynet rpc...') + sesh.connect() + yield sesh + sesh.disconnect() def validate_user_config_request(req: str): params = req.split(' ') diff --git a/skynet/frontend/telegram.py b/skynet/frontend/telegram.py index 
3287b3a..65a6fcb 100644 --- a/skynet/frontend/telegram.py +++ b/skynet/frontend/telegram.py @@ -6,8 +6,6 @@ import logging from datetime import datetime -import pynng - from PIL import Image from trio_asyncio import aio_as_trio @@ -16,6 +14,7 @@ from telebot.types import ( ) from telebot.async_telebot import AsyncTeleBot +from ..db import open_database_connection from ..constants import * from . import * @@ -56,228 +55,274 @@ def prepare_metainfo_caption(tguser, meta: dict) -> str: async def run_skynet_telegram( + name: str, tg_token: str, - key_name: str = 'telegram-frontend', - cert_name: str = 'whitelist/telegram-frontend', - rpc_address: str = DEFAULT_RPC_ADDR + key_name: str = 'telegram-frontend.key', + cert_name: str = 'whitelist/telegram-frontend.cert', + rpc_address: str = DEFAULT_RPC_ADDR, + db_host: str = 'localhost:5432', + db_user: str = 'skynet', + db_pass: str = 'password' ): logging.basicConfig(level=logging.INFO) bot = AsyncTeleBot(tg_token) + logging.info(f'tg_token: {tg_token}') - async with open_skynet_rpc( - 'skynet-telegram-0', - rpc_address=rpc_address, - security=True, - cert_name=cert_name, - key_name=key_name - ) as rpc_call: + async with open_database_connection( + db_user, db_pass, db_host + ) as db_call: + with open_skynet_rpc( + f'skynet-telegram-{name}', + rpc_address=rpc_address, + cert_name=cert_name, + key_name=key_name + ) as session: - async def _rpc_call( - uid: int, - method: str, - params: dict = {} - ): - return await rpc_call( - method, params, uid=f'{PREFIX}+{uid}') + @bot.message_handler(commands=['help']) + async def send_help(message): + splt_msg = message.text.split(' ') - @bot.message_handler(commands=['help']) - async def send_help(message): - splt_msg = message.text.split(' ') - - if len(splt_msg) == 1: - await bot.reply_to(message, HELP_TEXT) - - else: - param = splt_msg[1] - if param in HELP_TOPICS: - await bot.reply_to(message, HELP_TOPICS[param]) + if len(splt_msg) == 1: + await bot.reply_to(message, HELP_TEXT) 
else: - await bot.reply_to(message, HELP_UNKWNOWN_PARAM) + param = splt_msg[1] + if param in HELP_TOPICS: + await bot.reply_to(message, HELP_TOPICS[param]) - @bot.message_handler(commands=['cool']) - async def send_cool_words(message): - await bot.reply_to(message, '\n'.join(COOL_WORDS)) + else: + await bot.reply_to(message, HELP_UNKWNOWN_PARAM) - @bot.message_handler(commands=['txt2img']) - async def send_txt2img(message): - chat = message.chat + @bot.message_handler(commands=['cool']) + async def send_cool_words(message): + await bot.reply_to(message, '\n'.join(COOL_WORDS)) - prompt = ' '.join(message.text.split(' ')[1:]) + @bot.message_handler(commands=['txt2img']) + async def send_txt2img(message): + chat = message.chat + reply_id = None + if chat.type == 'group' and chat.id == GROUP_ID: + reply_id = message.message_id - if len(prompt) == 0: - await bot.reply_to(message, 'Empty text prompt ignored.') - return + user_id = f'tg+{message.from_user.id}' - logging.info(f'mid: {message.id}') - resp = await _rpc_call( - message.from_user.id, - 'txt2img', - {'prompt': prompt} - ) - logging.info(f'resp to {message.id} arrived') + prompt = ' '.join(message.text.split(' ')[1:]) - resp_txt = '' - result = MessageToDict(resp.result) - if 'error' in resp.result: - resp_txt = resp.result['message'] + if len(prompt) == 0: + await bot.reply_to(message, 'Empty text prompt ignored.') + return - else: - logging.info(result['id']) - img_raw = zlib.decompress(bytes.fromhex(result['img'])) - logging.info(f'got image of size: {len(img_raw)}') - img = Image.open(io.BytesIO(img_raw)) + logging.info(f'mid: {message.id}') + user = await db_call('get_or_create_user', user_id) + user_config = {**(await db_call('get_user_config', user))} + del user_config['id'] - await bot.send_photo( - GROUP_ID, - caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']), - photo=img, - reply_markup=build_redo_menu() + resp = await session.rpc( + 'dgpu_call', { + 'method': 'diffuse', + 
'params': { + 'prompt': prompt, + **user_config + } + }, + timeout=60 ) - return + logging.info(f'resp to {message.id} arrived') - await bot.reply_to(message, resp_txt) + resp_txt = '' + result = MessageToDict(resp.result) + if 'error' in resp.result: + resp_txt = resp.result['message'] + await bot.reply_to(message, resp_txt) - @bot.message_handler(func=lambda message: True, content_types=['photo']) - async def send_img2img(message): - chat = message.chat + else: + logging.info(result['id']) + img_raw = resp.bin + logging.info(f'got image of size: {len(img_raw)}') + img = Image.open(io.BytesIO(img_raw)) - if not message.caption.startswith('/img2img'): - return + await bot.send_photo( + GROUP_ID, + caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']), + photo=img, + reply_to_message_id=reply_id, + reply_markup=build_redo_menu() + ) + return - prompt = ' '.join(message.caption.split(' ')[1:]) - if len(prompt) == 0: - await bot.reply_to(message, 'Empty text prompt ignored.') - return + @bot.message_handler(func=lambda message: True, content_types=['photo']) + async def send_img2img(message): + chat = message.chat + reply_id = None + if chat.type == 'group' and chat.id == GROUP_ID: + reply_id = message.message_id - file_id = message.photo[-1].file_id - file_path = (await bot.get_file(file_id)).file_path - file_raw = await bot.download_file(file_path) - img = zlib.compress(file_raw) + user_id = f'tg+{message.from_user.id}' - logging.info(f'mid: {message.id}') - resp = await _rpc_call( - message.from_user.id, - 'img2img', - {'prompt': prompt, 'img': img.hex()} - ) - logging.info(f'resp to {message.id} arrived') + if not message.caption.startswith('/img2img'): + await bot.reply_to( + message, + 'For image to image you need to add /img2img to the beggining of your caption' + ) + return - resp_txt = '' - result = MessageToDict(resp.result) - if 'error' in resp.result: - resp_txt = resp.result['message'] + prompt = ' '.join(message.caption.split(' 
')[1:]) - else: - logging.info(result['id']) - img_raw = zlib.decompress(bytes.fromhex(result['img'])) - logging.info(f'got image of size: {len(img_raw)}') - img = Image.open(io.BytesIO(img_raw)) + if len(prompt) == 0: + await bot.reply_to(message, 'Empty text prompt ignored.') + return - await bot.send_media_group( - GROUP_ID, - media=[ - InputMediaPhoto(file_id), - InputMediaPhoto( - img, - caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']) - ) - ] + file_id = message.photo[-1].file_id + file_path = (await bot.get_file(file_id)).file_path + file_raw = await bot.download_file(file_path) + + logging.info(f'mid: {message.id}') + + user = await db_call('get_or_create_user', user_id) + user_config = {**(await db_call('get_user_config', user))} + del user_config['id'] + + resp = await session.rpc( + 'dgpu_call', { + 'method': 'diffuse', + 'params': { + 'prompt': prompt, + **user_config + } + }, + binext=file_raw, + timeout=60 ) - return + logging.info(f'resp to {message.id} arrived') - await bot.reply_to(message, resp_txt) + resp_txt = '' + result = MessageToDict(resp.result) + if 'error' in resp.result: + resp_txt = resp.result['message'] + await bot.reply_to(message, resp_txt) - @bot.message_handler(commands=['img2img']) - async def redo_txt2img(message): - await bot.reply_to( - message, - 'seems you tried to do an img2img command without sending image' - ) + else: + logging.info(result['id']) + img_raw = resp.bin + logging.info(f'got image of size: {len(img_raw)}') + img = Image.open(io.BytesIO(img_raw)) - async def _redo(message): - resp = await _rpc_call(message.from_user.id, 'redo') + await bot.send_media_group( + GROUP_ID, + media=[ + InputMediaPhoto(file_id), + InputMediaPhoto( + img, + caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']) + ) + ], + reply_to_message_id=reply_id + ) + return - resp_txt = '' - result = MessageToDict(resp.result) - if 'error' in resp.result: - resp_txt = resp.result['message'] - 
else: - logging.info(result['id']) - img_raw = zlib.decompress(bytes.fromhex(result['img'])) - logging.info(f'got image of size: {len(img_raw)}') - img = Image.open(io.BytesIO(img_raw)) - - await bot.send_photo( - GROUP_ID, - caption=prepare_metainfo_caption(message.from_user, result['meta']['meta']), - photo=img, - reply_markup=build_redo_menu() + @bot.message_handler(commands=['img2img']) + async def img2img_missing_image(message): + await bot.reply_to( + message, + 'seems you tried to do an img2img command without sending image' ) - return - await bot.reply_to(message, resp_txt) + @bot.message_handler(commands=['redo']) + async def redo(message): + chat = message.chat + reply_id = None + if chat.type == 'group' and chat.id == GROUP_ID: + reply_id = message.message_id - @bot.message_handler(commands=['redo']) - async def redo_txt2img(message): - await _redo(message) + user_config = {**(await db_call('get_user_config', user))} + del user_config['id'] + prompt = await db_call('get_last_prompt_of', user) - @bot.message_handler(commands=['config']) - async def set_config(message): - rpc_params = {} - try: - attr, val, reply_txt = validate_user_config_request( - message.text) + resp = await session.rpc( + 'dgpu_call', { + 'method': 'diffuse', + 'params': { + 'prompt': prompt, + **user_config + } + }, + timeout=60 + ) + logging.info(f'resp to {message.id} arrived') - resp = await _rpc_call( - message.from_user.id, - 'config', {'attr': attr, 'val': val}) + resp_txt = '' + result = MessageToDict(resp.result) + if 'error' in resp.result: + resp_txt = resp.result['message'] + await bot.reply_to(message, resp_txt) - except BaseException as e: - reply_txt = str(e) + else: + logging.info(result['id']) + img_raw = resp.bin + logging.info(f'got image of size: {len(img_raw)}') + img = Image.open(io.BytesIO(img_raw)) - finally: - await bot.reply_to(message, reply_txt) + await bot.send_photo( + GROUP_ID, + caption=prepare_metainfo_caption(message.from_user, 
result['meta']['meta']), + photo=img, + reply_to_message_id=reply_id + ) + return - @bot.message_handler(commands=['stats']) - async def user_stats(message): - resp = await _rpc_call( - message.from_user.id, - 'stats', - {} - ) - stats = resp.result + @bot.message_handler(commands=['config']) + async def set_config(message): + rpc_params = {} + try: + attr, val, reply_txt = validate_user_config_request( + message.text) - stats_str = f'generated: {stats["generated"]}\n' - stats_str += f'joined: {stats["joined"]}\n' - stats_str += f'role: {stats["role"]}\n' + logging.info(f'user config update: {attr} to {val}') + await db_call('update_user_config', + user, req.params['attr'], req.params['val']) + logging.info('done') - await bot.reply_to( - message, stats_str) + except BaseException as e: + reply_txt = str(e) - @bot.message_handler(commands=['donate']) - async def donation_info(message): - await bot.reply_to( - message, DONATION_INFO) + finally: + await bot.reply_to(message, reply_txt) - @bot.message_handler(commands=['say']) - async def say(message): - chat = message.chat - user = message.from_user + @bot.message_handler(commands=['stats']) + async def user_stats(message): - if (chat.type == 'group') or (user.id != 383385940): - return + generated, joined, role = await db_call('get_user_stats', user) - await bot.send_message(GROUP_ID, message.text[4:]) + stats_str = f'generated: {generated}\n' + stats_str += f'joined: {joined}\n' + stats_str += f'role: {role}\n' + + await bot.reply_to( + message, stats_str) + + @bot.message_handler(commands=['donate']) + async def donation_info(message): + await bot.reply_to( + message, DONATION_INFO) + + @bot.message_handler(commands=['say']) + async def say(message): + chat = message.chat + user = message.from_user + + if (chat.type == 'group') or (user.id != 383385940): + return + + await bot.send_message(GROUP_ID, message.text[4:]) - @bot.message_handler(func=lambda message: True) - async def echo_message(message): - if 
message.text[0] == '/': - await bot.reply_to(message, UNKNOWN_CMD_TEXT) + @bot.message_handler(func=lambda message: True) + async def echo_message(message): + if message.text[0] == '/': + await bot.reply_to(message, UNKNOWN_CMD_TEXT) @bot.callback_query_handler(func=lambda call: True) async def callback_query(call): @@ -289,4 +334,4 @@ async def run_skynet_telegram( await _redo(call) - await aio_as_trio(bot.infinity_polling()) + await aio_as_trio(bot.infinity_polling)() diff --git a/skynet/network.py b/skynet/network.py new file mode 100644 index 0000000..95fb60f --- /dev/null +++ b/skynet/network.py @@ -0,0 +1,341 @@ +#!/usr/bin/python + +import zlib +import socket + +from typing import Callable, Awaitable, Optional +from pathlib import Path +from contextlib import asynccontextmanager as acm +from cryptography import x509 +from cryptography.hazmat.primitives import serialization + +import trio +import pynng + +from pynng import TLSConfig, Context + +from .protobuf import * +from .constants import * + + +def get_random_port(): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(('', 0)) + return s.getsockname()[1] + + +def load_certs( + certs_dir: str, + cert_name: str, + key_name: str +): + certs_dir = Path(certs_dir).resolve() + tls_key_data = (certs_dir / key_name).read_bytes() + tls_key = serialization.load_pem_private_key( + tls_key_data, + password=None + ) + + tls_cert_data = (certs_dir / cert_name).read_bytes() + tls_cert = x509.load_pem_x509_certificate( + tls_cert_data + ) + + tls_whitelist = {} + for cert_path in (*(certs_dir / 'whitelist').glob('*.cert'), certs_dir / 'brain.cert'): + tls_whitelist[cert_path.stem] = x509.load_pem_x509_certificate( + cert_path.read_bytes() + ) + + return ( + SessionTLSConfig( + TLSConfig.MODE_SERVER, + own_key_string=tls_key_data, + own_cert_string=tls_cert_data + ), + + tls_whitelist + ) + + +def load_certs_client( + certs_dir: str, + cert_name: str, 
+ key_name: str, + ca_name: Optional[str] = None +): + certs_dir = Path(certs_dir).resolve() + if not ca_name: + ca_name = 'brain.cert' + + ca_cert_data = (certs_dir / ca_name).read_bytes() + + tls_key_data = (certs_dir / key_name).read_bytes() + + + tls_cert_data = (certs_dir / cert_name).read_bytes() + + + tls_whitelist = {} + for cert_path in (*(certs_dir / 'whitelist').glob('*.cert'), certs_dir / 'brain.cert'): + tls_whitelist[cert_path.stem] = x509.load_pem_x509_certificate( + cert_path.read_bytes() + ) + + return ( + SessionTLSConfig( + TLSConfig.MODE_CLIENT, + own_key_string=tls_key_data, + own_cert_string=tls_cert_data, + ca_string=ca_cert_data + ), + + tls_whitelist + ) + + +class SessionError(BaseException): + ... + + +class SessionTLSConfig(TLSConfig): + + def __init__( + self, + mode, + server_name=None, + ca_string=None, + own_key_string=None, + own_cert_string=None, + auth_mode=None, + ca_files=None, + cert_key_file=None, + passwd=None + ): + super().__init__( + mode, + server_name=server_name, + ca_string=ca_string, + own_key_string=own_key_string, + own_cert_string=own_cert_string, + auth_mode=auth_mode, + ca_files=ca_files, + cert_key_file=cert_key_file, + passwd=passwd + ) + + if ca_string: + self.ca_cert = x509.load_pem_x509_certificate(ca_string) + + self.cert = x509.load_pem_x509_certificate(own_cert_string) + self.key = serialization.load_pem_private_key( + own_key_string, + password=passwd + ) + + +class SessionServer: + + def __init__( + self, + addr: str, + msg_handler: Callable[ + [SkynetRPCRequest, Context], Awaitable[SkynetRPCResponse] + ], + cert_name: Optional[str] = None, + key_name: Optional[str] = None, + cert_dir: str = DEFAULT_CERTS_DIR, + recv_max_size = 0 + ): + self.addr = addr + self.msg_handler = msg_handler + + self.cert_name = cert_name + self.tls_config = None + self.tls_whitelist = None + if cert_name and key_name: + self.cert_name = cert_name + self.tls_config, self.tls_whitelist = load_certs( + cert_dir, cert_name, 
key_name) + + self.addr = 'tls+' + self.addr + + self.recv_max_size = recv_max_size + + async def _handle_msg(self, req: SkynetRPCRequest, ctx: Context): + resp = await self.msg_handler(req, ctx) + + if self.tls_config: + resp.auth.cert = 'skynet' + resp.auth.sig = sign_protobuf_msg( + resp, self.tls_config.key) + + raw_msg = zlib.compress(resp.SerializeToString()) + + await ctx.asend(raw_msg) + + ctx.close() + + async def _listener (self, sock): + async with trio.open_nursery() as n: + while True: + ctx = sock.new_context() + + raw_msg = await ctx.arecv() + raw_size = len(raw_msg) + logging.debug(f'rpc server new msg {raw_size} bytes') + + try: + msg = zlib.decompress(raw_msg) + msg_size = len(msg) + + except zlib.error: + logging.warning(f'Zlib decompress error, dropping msg of size {len(raw_msg)}') + continue + + logging.debug(f'msg after decompress {msg_size} bytes, +{msg_size - raw_size} bytes') + + req = SkynetRPCRequest() + try: + req.ParseFromString(msg) + + except google.protobuf.message.DecodeError: + logging.warning(f'Dropping malfomed msg of size {len(msg)}') + continue + + logging.debug(f'msg method: {req.method}') + + if self.tls_config: + if req.auth.cert not in self.tls_whitelist: + logging.warning( + f'{req.auth.cert} not in tls whitelist') + continue + + try: + verify_protobuf_msg(req, self.tls_whitelist[req.auth.cert]) + + except ValueError: + logging.warning( + f'{req.cert} sent an unauthenticated msg') + continue + + n.start_soon(self._handle_msg, req, ctx) + + @acm + async def open(self): + with pynng.Rep0( + recv_max_size=self.recv_max_size + ) as sock: + + if self.tls_config: + sock.tls_config = self.tls_config + + sock.listen(self.addr) + + logging.debug(f'server socket listening at {self.addr}') + + async with trio.open_nursery() as n: + n.start_soon(self._listener, sock) + + try: + yield self + + finally: + n.cancel_scope.cancel() + + logging.debug('server socket is off.') + + +class SessionClient: + + def __init__( + self, + 
connect_addr: str, + uid: str, + cert_name: Optional[str] = None, + key_name: Optional[str] = None, + ca_name: Optional[str] = None, + cert_dir: str = DEFAULT_CERTS_DIR, + recv_max_size = 0 + ): + self.uid = uid + self.connect_addr = connect_addr + + self.cert_name = None + self.tls_config = None + self.tls_whitelist = None + self.tls_cert = None + self.tls_key = None + if cert_name and key_name: + self.cert_name = Path(cert_name).stem + self.tls_config, self.tls_whitelist = load_certs_client( + cert_dir, cert_name, key_name, ca_name=ca_name) + + if not self.connect_addr.startswith('tls'): + self.connect_addr = 'tls+' + self.connect_addr + + self.recv_max_size = recv_max_size + + self._connected = False + self._sock = None + + def connect(self): + self._sock = pynng.Req0( + recv_max_size=0, + name=self.uid + ) + + if self.tls_config: + self._sock.tls_config = self.tls_config + + logging.debug(f'client is dialing {self.connect_addr}...') + self._sock.dial(self.connect_addr, block=True) + self._connected = True + logging.debug(f'client is connected to {self.connect_addr}') + + def disconnect(self): + self._sock.close() + self._connected = False + logging.debug(f'client disconnected.') + + async def rpc( + self, + method: str, + params: dict = {}, + binext: Optional[bytes] = None, + timeout: float = 2. 
+ ): + if not self._connected: + raise SessionError('tried to use rpc without connecting') + + req = SkynetRPCRequest() + req.uid = self.uid + req.method = method + req.params.update(params) + if binext: + logging.debug('added binary extension') + req.bin = binext + + if self.tls_config: + req.auth.cert = self.cert_name + req.auth.sig = sign_protobuf_msg(req, self.tls_config.key) + + with trio.fail_after(timeout): + ctx = self._sock.new_context() + raw_req = zlib.compress(req.SerializeToString()) + logging.debug(f'rpc client sending new msg {method} of size {len(raw_req)}') + await ctx.asend(raw_req) + logging.debug('sent, awaiting response...') + raw_resp = await ctx.arecv() + logging.debug(f'rpc client got response of size {len(raw_resp)}') + raw_resp = zlib.decompress(raw_resp) + + resp = SkynetRPCResponse() + resp.ParseFromString(raw_resp) + ctx.close() + + if self.tls_config: + verify_protobuf_msg(resp, self.tls_config.ca_cert) + + return resp diff --git a/skynet/protobuf/__init__.py b/skynet/protobuf/__init__.py index b985940..acafec8 100644 --- a/skynet/protobuf/__init__.py +++ b/skynet/protobuf/__init__.py @@ -1,29 +1,4 @@ #!/usr/bin/python -from typing import Optional -from dataclasses import dataclass, asdict - -from google.protobuf.json_format import MessageToDict - from .auth import * from .skynet_pb2 import * - - -class Struct: - - def to_dict(self): - return asdict(self) - - -@dataclass -class DiffusionParameters(Struct): - algo: str - prompt: str - step: int - width: int - height: int - guidance: float - strength: float - seed: Optional[int] - image: bool # if true indicates a bytestream is next msg - upscaler: Optional[str] diff --git a/skynet/protobuf/auth.py b/skynet/protobuf/auth.py index e2904cb..876683d 100644 --- a/skynet/protobuf/auth.py +++ b/skynet/protobuf/auth.py @@ -7,7 +7,8 @@ from hashlib import sha256 from collections import OrderedDict from google.protobuf.json_format import MessageToDict -from OpenSSL.crypto import PKey, X509, 
verify, sign +from cryptography.hazmat.primitives import serialization, hashes +from cryptography.hazmat.primitives.asymmetric import padding from .skynet_pb2 import * @@ -46,20 +47,23 @@ def serialize_msg_deterministic(msg): if field_descriptor.message_type.name == 'Struct': hash_dict(MessageToDict(getattr(msg, field_name))) - deterministic_msg = shasum.hexdigest() + deterministic_msg = shasum.digest() return deterministic_msg -def sign_protobuf_msg(msg, key: PKey): - return sign( - key, serialize_msg_deterministic(msg), 'sha256').hex() +def sign_protobuf_msg(msg, key): + return key.sign( + serialize_msg_deterministic(msg), + padding.PKCS1v15(), + hashes.SHA256() + ).hex() -def verify_protobuf_msg(msg, cert: X509): - return verify( - cert, +def verify_protobuf_msg(msg, cert): + return cert.public_key().verify( bytes.fromhex(msg.auth.sig), serialize_msg_deterministic(msg), - 'sha256' + padding.PKCS1v15(), + hashes.SHA256() ) diff --git a/skynet/protobuf/skynet.proto b/skynet/protobuf/skynet.proto index 6e66274..0bdccad 100644 --- a/skynet/protobuf/skynet.proto +++ b/skynet/protobuf/skynet.proto @@ -13,18 +13,12 @@ message SkynetRPCRequest { string uid = 1; string method = 2; google.protobuf.Struct params = 3; - optional Auth auth = 4; + optional bytes bin = 4; + optional Auth auth = 5; } message SkynetRPCResponse { google.protobuf.Struct result = 1; - optional Auth auth = 2; -} - -message DGPUBusMessage { - string rid = 1; - string nid = 2; - string method = 3; - google.protobuf.Struct params = 4; - optional Auth auth = 5; + optional bytes bin = 2; + optional Auth auth = 3; } diff --git a/skynet/protobuf/skynet_pb2.py b/skynet/protobuf/skynet_pb2.py index dd7db33..84b0527 100644 --- a/skynet/protobuf/skynet_pb2.py +++ b/skynet/protobuf/skynet_pb2.py @@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default() from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -DESCRIPTOR = 
_descriptor_pool.Default().AddSerializedFile(b'\n\x0cskynet.proto\x12\x06skynet\x1a\x1cgoogle/protobuf/struct.proto\"!\n\x04\x41uth\x12\x0c\n\x04\x63\x65rt\x18\x01 \x01(\t\x12\x0b\n\x03sig\x18\x02 \x01(\t\"\x82\x01\n\x10SkynetRPCRequest\x12\x0b\n\x03uid\x18\x01 \x01(\t\x12\x0e\n\x06method\x18\x02 \x01(\t\x12\'\n\x06params\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x04 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"f\n\x11SkynetRPCResponse\x12\'\n\x06result\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x02 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"\x8d\x01\n\x0e\x44GPUBusMessage\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x0b\n\x03nid\x18\x02 \x01(\t\x12\x0e\n\x06method\x18\x03 \x01(\t\x12\'\n\x06params\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x05 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_authb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cskynet.proto\x12\x06skynet\x1a\x1cgoogle/protobuf/struct.proto\"!\n\x04\x41uth\x12\x0c\n\x04\x63\x65rt\x18\x01 \x01(\t\x12\x0b\n\x03sig\x18\x02 \x01(\t\"\x9c\x01\n\x10SkynetRPCRequest\x12\x0b\n\x03uid\x18\x01 \x01(\t\x12\x0e\n\x06method\x18\x02 \x01(\t\x12\'\n\x06params\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x10\n\x03\x62in\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x12\x1f\n\x04\x61uth\x18\x05 \x01(\x0b\x32\x0c.skynet.AuthH\x01\x88\x01\x01\x42\x06\n\x04_binB\x07\n\x05_auth\"\x80\x01\n\x11SkynetRPCResponse\x12\'\n\x06result\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x10\n\x03\x62in\x18\x02 \x01(\x0cH\x00\x88\x01\x01\x12\x1f\n\x04\x61uth\x18\x03 \x01(\x0b\x32\x0c.skynet.AuthH\x01\x88\x01\x01\x42\x06\n\x04_binB\x07\n\x05_authb\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'skynet_pb2', globals()) @@ -24,9 +24,7 @@ if _descriptor._USE_C_DESCRIPTORS == False: 
_AUTH._serialized_start=54 _AUTH._serialized_end=87 _SKYNETRPCREQUEST._serialized_start=90 - _SKYNETRPCREQUEST._serialized_end=220 - _SKYNETRPCRESPONSE._serialized_start=222 - _SKYNETRPCRESPONSE._serialized_end=324 - _DGPUBUSMESSAGE._serialized_start=327 - _DGPUBUSMESSAGE._serialized_end=468 + _SKYNETRPCREQUEST._serialized_end=246 + _SKYNETRPCRESPONSE._serialized_start=249 + _SKYNETRPCRESPONSE._serialized_end=377 # @@protoc_insertion_point(module_scope) diff --git a/skynet/utils.py b/skynet/utils.py index ba1ce2d..637078b 100644 --- a/skynet/utils.py +++ b/skynet/utils.py @@ -1,5 +1,6 @@ #!/usr/bin/python +import time import random from typing import Optional @@ -21,6 +22,10 @@ from huggingface_hub import login from .constants import ALGOS +def time_ms(): + return int(time.time() * 1000) + + def convert_from_cv2_to_image(img: np.ndarray) -> Image: # return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) return Image.fromarray(img) @@ -164,3 +169,13 @@ def upscale( image.save(output) + + +def download_all_models(hf_token: str): + assert torch.cuda.is_available() + + login(token=hf_token) + for model in ALGOS: + print(f'DOWNLOADING {model.upper()}') + pipeline_for(model) + diff --git a/tests/conftest.py b/tests/conftest.py index 64a369f..0b4c335 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,89 +3,30 @@ import os import json import time -import random -import string import logging -from functools import partial from pathlib import Path +from functools import partial -import trio import pytest -import psycopg2 -import trio_asyncio from docker.types import Mount, DeviceRequest -from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT -from skynet.constants import * +from skynet.db import open_new_database from skynet.brain import run_skynet +from skynet.network import get_random_port +from skynet.constants import * @pytest.fixture(scope='session') def postgres_db(dockerctl): - rpassword = ''.join( - random.choice(string.ascii_lowercase) - for i in 
range(12)) - password = ''.join( - random.choice(string.ascii_lowercase) - for i in range(12)) - - with dockerctl.run( - 'postgres', - name='skynet-test-postgres', - ports={'5432/tcp': None}, - environment={ - 'POSTGRES_PASSWORD': rpassword - } - ) as containers: - container = containers[0] - # ip = container.attrs['NetworkSettings']['IPAddress'] - port = container.ports['5432/tcp'][0]['HostPort'] - host = f'localhost:{port}' - - for log in container.logs(stream=True): - log = log.decode().rstrip() - logging.info(log) - if ('database system is ready to accept connections' in log or - 'database system is shut down' in log): - break - - # why print the system is ready to accept connections when its not - # postgres? wtf - time.sleep(1) - logging.info('creating skynet db...') - - conn = psycopg2.connect( - user='postgres', - password=rpassword, - host='localhost', - port=port - ) - logging.info('connected...') - conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) - with conn.cursor() as cursor: - cursor.execute( - f'CREATE USER {DB_USER} WITH PASSWORD \'{password}\'') - cursor.execute( - f'CREATE DATABASE {DB_NAME}') - cursor.execute( - f'GRANT ALL PRIVILEGES ON DATABASE {DB_NAME} TO {DB_USER}') - - conn.close() - - logging.info('done.') - yield container, password, host + with open_new_database() as db_params: + yield db_params @pytest.fixture -async def skynet_running(postgres_db): - db_container, db_pass, db_host = postgres_db - - async with run_skynet( - db_pass=db_pass, - db_host=db_host - ): +async def skynet_running(): + async with run_skynet(): yield @@ -99,11 +40,13 @@ def dgpu_workers(request, dockerctl, skynet_running): cmds = [] for i in range(num_containers): + dgpu_addr = f'tcp://127.0.0.1:{get_random_port()}' cmd = f''' pip install -e . 
&& \ skynet run dgpu \ --algos=\'{json.dumps(initial_algos)}\' \ - --uid=dgpu-{i} + --uid=dgpu-{i} \ + --dgpu={dgpu_addr} ''' cmds.append(['bash', '-c', cmd]) @@ -114,16 +57,15 @@ def dgpu_workers(request, dockerctl, skynet_running): name='skynet-test-runtime-cuda', commands=cmds, environment={ - 'HF_TOKEN': os.environ['HF_TOKEN'], 'HF_HOME': '/skynet/hf_home' }, network='host', mounts=mounts, device_requests=devices, - num=num_containers + num=num_containers, ) as containers: yield containers - #for i, container in enumerate(containers): - # logging.info(f'container {i} logs:') - # logging.info(container.logs().decode()) + for i, container in enumerate(containers): + logging.info(f'container {i} logs:') + logging.info(container.logs().decode()) diff --git a/tests/test_dgpu.py b/tests/test_dgpu.py index 4ce93bf..c187af0 100644 --- a/tests/test_dgpu.py +++ b/tests/test_dgpu.py @@ -12,29 +12,26 @@ from functools import partial import trio import pytest -import trio_asyncio from PIL import Image from google.protobuf.json_format import MessageToDict from skynet.brain import SkynetDGPUComputeError -from skynet.constants import * +from skynet.network import get_random_port, SessionServer +from skynet.protobuf import SkynetRPCResponse from skynet.frontend import open_skynet_rpc +from skynet.constants import * -async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0): +async def wait_for_dgpus(session, amount: int, timeout: float = 30.0): gpu_ready = False - start_time = time.time() - current_time = time.time() - while not gpu_ready and (current_time - start_time) < timeout: - res = await rpc('dgpu_workers') - if res.result['ok'] >= amount: - break + with trio.fail_after(timeout): + while not gpu_ready: + res = await session.rpc('dgpu_workers') + if res.result['ok'] >= amount: + break - await trio.sleep(1) - current_time = time.time() - - assert (current_time - start_time) < timeout + await trio.sleep(1) _images = set() @@ -48,34 +45,33 @@ async def 
check_request_img( ): global _images - async with open_skynet_rpc( + with open_skynet_rpc( uid, - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as rpc_call: - res = await rpc_call( - 'txt2img', { - 'prompt': 'red old tractor in a sunny wheat field', - 'step': 28, - 'width': width, 'height': height, - 'guidance': 7.5, - 'seed': None, - 'algo': list(ALGOS.keys())[i], - 'upscaler': upscaler - }) + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + res = await session.rpc( + 'dgpu_call', { + 'method': 'diffuse', + 'params': { + 'prompt': 'red old tractor in a sunny wheat field', + 'step': 28, + 'width': width, 'height': height, + 'guidance': 7.5, + 'seed': None, + 'algo': list(ALGOS.keys())[i], + 'upscaler': upscaler + } + }, + timeout=60 + ) if 'error' in res.result: raise SkynetDGPUComputeError(MessageToDict(res.result)) - if upscaler == 'x4': - width *= 4 - height *= 4 - - img_raw = zlib.decompress(bytes.fromhex(res.result['img'])) + img_raw = res.bin img_sha = sha256(img_raw).hexdigest() - img = Image.frombytes( - 'RGB', (width, height), img_raw) + img = Image.open(io.BytesIO(img_raw)) if expect_unique and img_sha in _images: raise ValueError('Duplicated image sha: {img_sha}') @@ -96,13 +92,12 @@ async def test_dgpu_worker_compute_error(dgpu_workers): then generate a smaller image to show gpu worker recovery ''' - async with open_skynet_rpc( + with open_skynet_rpc( 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as test_rpc: - await wait_for_dgpus(test_rpc, 1) + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 1) with pytest.raises(SkynetDGPUComputeError) as e: await check_request_img(0, width=4096, height=4096) @@ -112,20 +107,35 @@ async def test_dgpu_worker_compute_error(dgpu_workers): await check_request_img(0) +@pytest.mark.parametrize( + 'dgpu_workers', [(1, ['midj'])], indirect=True) +async def 
test_dgpu_worker(dgpu_workers): + '''Generate one image in a single dgpu worker + ''' + + with open_skynet_rpc( + 'test-ctx', + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 1) + + await check_request_img(0) + + @pytest.mark.parametrize( 'dgpu_workers', [(1, ['midj', 'stable'])], indirect=True) -async def test_dgpu_workers(dgpu_workers): +async def test_dgpu_worker_two_models(dgpu_workers): '''Generate two images in a single dgpu worker using two different models. ''' - async with open_skynet_rpc( + with open_skynet_rpc( 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as test_rpc: - await wait_for_dgpus(test_rpc, 1) + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 1) await check_request_img(0) await check_request_img(1) @@ -138,14 +148,12 @@ async def test_dgpu_worker_upscale(dgpu_workers): two different models. ''' - async with open_skynet_rpc( + with open_skynet_rpc( 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as test_rpc: - await wait_for_dgpus(test_rpc, 1) - logging.error('UPSCALE') + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 1) img = await check_request_img(0, upscaler='x4') @@ -157,13 +165,12 @@ async def test_dgpu_worker_upscale(dgpu_workers): async def test_dgpu_workers_two(dgpu_workers): '''Generate two images in two separate dgpu workers ''' - async with open_skynet_rpc( + with open_skynet_rpc( 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as test_rpc: - await wait_for_dgpus(test_rpc, 2) + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 2, timeout=60) async with trio.open_nursery() as n: n.start_soon(check_request_img, 0) @@ -175,13 +182,12 @@ async def 
test_dgpu_workers_two(dgpu_workers): async def test_dgpu_worker_algo_swap(dgpu_workers): '''Generate an image using a non default model ''' - async with open_skynet_rpc( + with open_skynet_rpc( 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as test_rpc: - await wait_for_dgpus(test_rpc, 1) + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 1) await check_request_img(5) @@ -191,33 +197,32 @@ async def test_dgpu_rotation_next_worker(dgpu_workers): '''Connect three dgpu workers, disconnect and check next_worker rotation happens correctly ''' - async with open_skynet_rpc( + with open_skynet_rpc( 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as test_rpc: - await wait_for_dgpus(test_rpc, 3) + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 3) - res = await test_rpc('dgpu_next') + res = await session.rpc('dgpu_next') assert 'ok' in res.result assert res.result['ok'] == 0 await check_request_img(0) - res = await test_rpc('dgpu_next') + res = await session.rpc('dgpu_next') assert 'ok' in res.result assert res.result['ok'] == 1 await check_request_img(0) - res = await test_rpc('dgpu_next') + res = await session.rpc('dgpu_next') assert 'ok' in res.result assert res.result['ok'] == 2 await check_request_img(0) - res = await test_rpc('dgpu_next') + res = await session.rpc('dgpu_next') assert 'ok' in res.result assert res.result['ok'] == 0 @@ -228,13 +233,12 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers): '''Connect three dgpu workers, disconnect the first one and check next_worker rotation happens correctly ''' - async with open_skynet_rpc( + with open_skynet_rpc( 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as test_rpc: - await wait_for_dgpus(test_rpc, 3) + cert_name='whitelist/testing.cert', + key_name='testing.key' 
+ ) as session: + await wait_for_dgpus(session, 3) await trio.sleep(3) @@ -245,7 +249,7 @@ async def test_dgpu_rotation_next_worker_disconnect(dgpu_workers): dgpu_workers[0].wait() - res = await test_rpc('dgpu_workers') + res = await session.rpc('dgpu_workers') assert 'ok' in res.result assert res.result['ok'] == 2 @@ -258,26 +262,43 @@ async def test_dgpu_no_ack_node_disconnect(skynet_running): '''Mock a node that connects, gets a request but fails to acknowledge it, then check skynet correctly drops the node ''' - async with open_skynet_rpc( - 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as rpc_call: - res = await rpc_call('dgpu_online') - assert 'ok' in res.result + async def mock_rpc(req, ctx): + resp = SkynetRPCResponse() + resp.result.update({'error': 'can\'t do it mate'}) + return resp - await wait_for_dgpus(rpc_call, 1) + dgpu_addr = f'tcp://127.0.0.1:{get_random_port()}' + mock_server = SessionServer( + dgpu_addr, + mock_rpc, + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) - with pytest.raises(SkynetDGPUComputeError) as e: - await check_request_img(0) + async with mock_server.open(): + with open_skynet_rpc( + 'test-ctx', + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: - assert 'dgpu failed to acknowledge request' in str(e) + res = await session.rpc('dgpu_online', { + 'dgpu_addr': dgpu_addr, + 'cert': 'whitelist/testing.cert' + }) + assert 'ok' in res.result - res = await rpc_call('dgpu_workers') - assert 'ok' in res.result - assert res.result['ok'] == 0 + await wait_for_dgpus(session, 1) + + with pytest.raises(SkynetDGPUComputeError) as e: + await check_request_img(0) + + assert 'can\'t do it mate' in str(e.value) + + res = await session.rpc('dgpu_workers') + assert 'ok' in res.result + assert res.result['ok'] == 0 @pytest.mark.parametrize( @@ -286,13 +307,12 @@ async def test_dgpu_timeout_while_processing(dgpu_workers): '''Stop node while processing request to 
cause timeout and then check skynet correctly drops the node. ''' - async with open_skynet_rpc( + with open_skynet_rpc( 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as test_rpc: - await wait_for_dgpus(test_rpc, 1) + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 1) async def check_request_img_raises(): with pytest.raises(SkynetDGPUComputeError) as e: @@ -308,72 +328,62 @@ async def test_dgpu_timeout_while_processing(dgpu_workers): assert ec == 0 -@pytest.mark.parametrize( - 'dgpu_workers', [(1, ['midj'])], indirect=True) -async def test_dgpu_heartbeat(dgpu_workers): - ''' - ''' - async with open_skynet_rpc( - 'test-ctx', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as test_rpc: - await wait_for_dgpus(test_rpc, 1) - await trio.sleep(120) - - @pytest.mark.parametrize( 'dgpu_workers', [(1, ['midj'])], indirect=True) async def test_dgpu_img2img(dgpu_workers): - async with open_skynet_rpc( - '1', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as rpc_call: - await wait_for_dgpus(rpc_call, 1) + with open_skynet_rpc( + 'test-ctx', + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: + await wait_for_dgpus(session, 1) + await trio.sleep(2) - res = await rpc_call( - 'txt2img', { - 'prompt': 'red old tractor in a sunny wheat field', - 'step': 28, - 'width': 512, 'height': 512, - 'guidance': 7.5, - 'seed': None, - 'algo': list(ALGOS.keys())[0], - 'upscaler': None - }) + res = await session.rpc( + 'dgpu_call', { + 'method': 'diffuse', + 'params': { + 'prompt': 'red old tractor in a sunny wheat field', + 'step': 28, + 'width': 512, 'height': 512, + 'guidance': 7.5, + 'seed': None, + 'algo': list(ALGOS.keys())[0], + 'upscaler': None + } + }, + timeout=60 + ) if 'error' in res.result: raise SkynetDGPUComputeError(MessageToDict(res.result)) - img_raw = res.result['img'] - img = 
zlib.decompress(bytes.fromhex(img_raw)) - logging.info(img[:10]) - img = Image.open(io.BytesIO(img)) - + img_raw = res.bin + img = Image.open(io.BytesIO(img_raw)) img.save('txt2img.png') - res = await rpc_call( - 'img2img', { - 'prompt': 'red sports car in a sunny wheat field', - 'step': 28, - 'img': img_raw, - 'guidance': 12, - 'seed': None, - 'algo': list(ALGOS.keys())[0], - 'upscaler': 'x4' - }) + res = await session.rpc( + 'dgpu_call', { + 'method': 'diffuse', + 'params': { + 'prompt': 'red ferrari in a sunny wheat field', + 'step': 28, + 'guidance': 8, + 'strength': 0.7, + 'seed': None, + 'algo': list(ALGOS.keys())[0], + 'upscaler': 'x4' + } + }, + binext=img_raw, + timeout=60 + ) if 'error' in res.result: raise SkynetDGPUComputeError(MessageToDict(res.result)) - img_raw = res.result['img'] - img = zlib.decompress(bytes.fromhex(img_raw)) - logging.info(img[:10]) - img = Image.open(io.BytesIO(img)) - + img_raw = res.bin + img = Image.open(io.BytesIO(img_raw)) img.save('img2img.png') diff --git a/tests/test_skynet.py b/tests/test_skynet.py index 5572a70..1587d5d 100644 --- a/tests/test_skynet.py +++ b/tests/test_skynet.py @@ -9,6 +9,7 @@ import trio_asyncio from skynet.brain import run_skynet from skynet.structs import * +from skynet.network import SessionServer from skynet.frontend import open_skynet_rpc @@ -18,53 +19,68 @@ async def test_skynet(skynet_running): async def test_skynet_attempt_insecure(skynet_running): with pytest.raises(pynng.exceptions.NNGException) as e: - async with open_skynet_rpc('bad-actor'): - ... - - assert str(e.value) == 'Connection shutdown' + with open_skynet_rpc('bad-actor') as session: + with trio.fail_after(5): + await session.rpc('skynet_shutdown') async def test_skynet_dgpu_connection_simple(skynet_running): - async with open_skynet_rpc( + + async def rpc_handler(req, ctx): + ... 
+ + fake_dgpu_addr = 'tcp://127.0.0.1:41001' + rpc_server = SessionServer( + fake_dgpu_addr, + rpc_handler, + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) + + with open_skynet_rpc( 'dgpu-0', - security=True, - cert_name='whitelist/testing', - key_name='testing' - ) as rpc_call: + cert_name='whitelist/testing.cert', + key_name='testing.key' + ) as session: # check 0 nodes are connected - res = await rpc_call('dgpu_workers') - assert 'ok' in res.result + res = await session.rpc('dgpu_workers') + assert 'ok' in res.result.keys() assert res.result['ok'] == 0 # check next worker is None - res = await rpc_call('dgpu_next') - assert 'ok' in res.result + res = await session.rpc('dgpu_next') + assert 'ok' in res.result.keys() assert res.result['ok'] == None - # connect 1 dgpu - res = await rpc_call('dgpu_online') - assert 'ok' in res.result + async with rpc_server.open() as rpc_server: + # connect 1 dgpu + res = await session.rpc( + 'dgpu_online', { + 'dgpu_addr': fake_dgpu_addr, + 'cert': 'whitelist/testing.cert' + }) + assert 'ok' in res.result.keys() - # check 1 node is connected - res = await rpc_call('dgpu_workers') - assert 'ok' in res.result - assert res.result['ok'] == 1 + # check 1 node is connected + res = await session.rpc('dgpu_workers') + assert 'ok' in res.result.keys() + assert res.result['ok'] == 1 - # check next worker is 0 - res = await rpc_call('dgpu_next') - assert 'ok' in res.result - assert res.result['ok'] == 0 + # check next worker is 0 + res = await session.rpc('dgpu_next') + assert 'ok' in res.result.keys() + assert res.result['ok'] == 0 - # disconnect 1 dgpu - res = await rpc_call('dgpu_offline') - assert 'ok' in res.result + # disconnect 1 dgpu + res = await session.rpc('dgpu_offline') + assert 'ok' in res.result.keys() # check 0 nodes are connected - res = await rpc_call('dgpu_workers') - assert 'ok' in res.result + res = await session.rpc('dgpu_workers') + assert 'ok' in res.result.keys() assert res.result['ok'] == 0 # check 
next worker is None - res = await rpc_call('dgpu_next') - assert 'ok' in res.result + res = await session.rpc('dgpu_next') + assert 'ok' in res.result.keys() assert res.result['ok'] == None diff --git a/tests/test_telegram.py b/tests/test_telegram.py new file mode 100644 index 0000000..d94a6bf --- /dev/null +++ b/tests/test_telegram.py @@ -0,0 +1,28 @@ +#!/usr/bin/python + +import trio + +from functools import partial + +from skynet.db import open_new_database +from skynet.brain import run_skynet +from skynet.config import load_skynet_ini +from skynet.frontend.telegram import run_skynet_telegram + + +if __name__ == '__main__': + '''You will need a telegram bot token configured on skynet.ini for this + ''' + with open_new_database() as db_params: + db_container, db_pass, db_host = db_params + config = load_skynet_ini() + + async def main(): + await run_skynet_telegram( + 'telegram-test', + config['skynet.telegram-test']['token'], + db_host=db_host, + db_pass=db_pass + ) + + trio.run(main)