diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml new file mode 100644 index 0000000..d09219d --- /dev/null +++ b/.github/workflows/bench.yml @@ -0,0 +1,67 @@ +name: bench + +on: + pull_request: + types: [ labeled, synchronize ] + branches: [ "main" ] + +env: + CARGO_TERM_COLOR: always + RUST_BACKTRACE: 1 + +jobs: + bench: + if: contains(github.event.pull_request.labels.*.name, 'bench') + + strategy: + fail-fast: false + matrix: + os: [windows-latest, macos-latest] + + runs-on: ${{ matrix.os }} + timeout-minutes: 120 + + steps: + - uses: actions/checkout@v3 + - uses: actions/cache@v3 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + target/ + key: ${{ runner.os }}-cargo-build-stable-${{ hashFiles('**/Cargo.toml') }} + + + - name: Install ONNX Runtime on Windows + if: matrix.os == 'windows-latest' + run: | + Invoke-WebRequest -Uri "https://github.com/microsoft/onnxruntime/releases/download/v1.17.1/onnxruntime-win-x64-1.17.1.zip" -OutFile "onnxruntime.zip" + Expand-Archive -Path "onnxruntime.zip" -DestinationPath "$env:RUNNER_TEMP" + echo "ONNXRUNTIME_DIR=$env:RUNNER_TEMP\onnxruntime-win-x64-1.17.1" | Out-File -Append -Encoding ascii $env:GITHUB_ENV + + - name: Install ONNX Runtime on macOS + if: matrix.os == 'macos-latest' + run: | + curl -L "https://github.com/microsoft/onnxruntime/releases/download/v1.17.1/onnxruntime-osx-x86_64-1.17.1.tgz" -o "onnxruntime.tgz" + mkdir -p $HOME/onnxruntime + tar -xzf onnxruntime.tgz -C $HOME/onnxruntime + echo "ONNXRUNTIME_DIR=$HOME/onnxruntime/onnxruntime-osx-x86_64-1.17.1" >> $GITHUB_ENV + + + - name: Set ONNX Runtime library path for macOS + if: matrix.os == 'macos-latest' + run: echo "ORT_DYLIB_PATH=$ONNXRUNTIME_DIR/libonnxruntime.dylib" >> $GITHUB_ENV + + - name: Set ONNX Runtime library path for Windows + if: matrix.os == 'windows-latest' + run: echo "ORT_DYLIB_PATH=$env:ONNXRUNTIME_DIR/onnxruntime.dll" | Out-File -Append -Encoding ascii $env:GITHUB_ENV + + + - name: io benchmark + 
uses: boa-dev/criterion-compare-action@v3.2.4 + with: + benchName: "modnet" + branchName: ${{ github.base_ref }} + token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index 28fc69c..d74f255 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,6 @@ www/assets/ mediamtx/ onnxruntime/ + +*.onnx +*.bin diff --git a/Cargo.toml b/Cargo.toml index 33e65c8..608f7c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "bevy_ort" description = "bevy ort (onnxruntime) plugin" -version = "0.5.0" +version = "0.12.9" edition = "2021" authors = ["mosure "] license = "MIT" @@ -26,21 +26,34 @@ exclude = [ default-run = "modnet" + [features] default = [ + "flame", + "flame_viewer", + "lightglue", "modnet", + "yolo_v8", ] -modnet = [] +flame_viewer = ["bevy_panorbit_camera"] + +flame = [] +lightglue = [] +modnet = ["rayon"] +yolo_v8 = [] [dependencies] bevy_args = "1.3" -image = "0.24" +bevy_panorbit_camera = { version = "0.18", optional = true } +bytemuck = "1.15" +image = "0.24" # upgrade with bevy +include_bytes_aligned = "0.1" ndarray = "0.15" +rayon = { version = "1.8", optional = true } +serde = "1.0" thiserror = "1.0" -tokio = { version = "1.36", features = ["full"] } - [dependencies.bevy] version = "0.13" @@ -48,23 +61,29 @@ default-features = false features = [ "bevy_asset", "bevy_core_pipeline", + "bevy_pbr", "bevy_render", "bevy_ui", "bevy_winit", "multi-threaded", "png", + "tonemapping_luts", ] [dependencies.ort] -version = "2.0.0-alpha.4" +version = "2.0.0-rc.2" default-features = false features = [ - "load-dynamic", + "download-binaries", "ndarray", ] +[dev-dependencies] +criterion = { version = "0.5", features = ["html_reports"] } + + [profile.dev.package."*"] opt-level = 3 @@ -81,7 +100,41 @@ opt-level = 3 path = "src/lib.rs" +[[bin]] +name = "flame" +path = "tools/flame.rs" +required-features = ["flame", "flame_viewer"] + +[[bin]] +name = "lightglue" +path = "tools/lightglue.rs" +required-features = ["lightglue"] + [[bin]] 
name = "modnet" path = "tools/modnet.rs" +required-features = ["modnet"] + +[[bin]] +name = "yolo_v8" +path = "tools/yolo_v8.rs" +required-features = ["yolo_v8"] + +[[bench]] +name = "lightglue" +path = "benches/lightglue.rs" +harness = false +required-features = ["lightglue"] + +[[bench]] +name = "modnet" +path = "benches/modnet.rs" +harness = false +required-features = ["modnet"] + +[[bench]] +name = "yolo_v8" +path = "benches/yolo_v8.rs" +harness = false +required-features = ["yolo_v8"] diff --git a/LICENSE-GPL-3.0 b/LICENSE-GPL-3.0 new file mode 100644 index 0000000..be3f7b2 --- /dev/null +++ b/LICENSE-GPL-3.0 @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. 
+ + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". 
"Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. 
+ + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. 
Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. 
+ + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. 
This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. 
+ + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. 
+ + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. 
+ + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. 
If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. 
If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. 
+ + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. 
For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. 
+ + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. 
You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. 
+ + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. 
+ + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. 
+ + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + <one line to give the program's name and a brief idea of what it does.> + Copyright (C) <year> <name of author> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code.
There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +<https://www.gnu.org/licenses/>. diff --git a/LICENSE b/LICENSE-MIT similarity index 100% rename from LICENSE rename to LICENSE-MIT diff --git a/README.md b/README.md index 2f81f48..18e0a75 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,13 @@ # bevy_ort 🪨 [![test](https://github.com/mosure/bevy_ort/workflows/test/badge.svg)](https://github.com/Mosure/bevy_ort/actions?query=workflow%3Atest) -[![GitHub License](https://img.shields.io/github/license/mosure/bevy_ort)](https://raw.githubusercontent.com/mosure/bevy_ort/main/LICENSE) -[![GitHub Last Commit](https://img.shields.io/github/last-commit/mosure/bevy_ort)](https://github.com/mosure/bevy_ort) -[![GitHub Releases](https://img.shields.io/github/v/release/mosure/bevy_ort?include_prereleases&sort=semver)](https://github.com/mosure/bevy_ort/releases) -[![GitHub Issues](https://img.shields.io/github/issues/mosure/bevy_ort)](https://github.com/mosure/bevy_ort/issues) -[![Average time to resolve an issue](https://isitmaintained.com/badge/resolution/mosure/bevy_ort.svg)](http://isitmaintained.com/project/mosure/bevy_ort) +[![GitHub License](https://img.shields.io/badge/license-MIT%2FGPL%E2%80%933.0-blue.svg)](https://github.com/mosure/bevy_ort#license) [![crates.io](https://img.shields.io/crates/v/bevy_ort.svg)](https://crates.io/crates/bevy_ort) a bevy plugin for the [ort](https://docs.rs/ort/latest/ort/) library -![person](assets/person.png) -![mask](assets/mask.png) +![person](assets/images/person.png) +![mask](assets/images/mask.png) *> modnet inference example* @@ -24,6 +20,12 @@ a bevy plugin for the [ort](https://docs.rs/ort/latest/ort/) library - [X] batched
modnet preprocessing - [X] compute task pool inference scheduling +### models +- [X] lightglue (feature matching) +- [X] modnet (photographic portrait matting) +- [X] yolo_v8 (object detection) +- [X] flame (parametric head model) + ## library usage @@ -32,12 +34,12 @@ use bevy::prelude::*; use bevy_ort::{ BevyOrtPlugin, - inputs, - models::modnet::{ - images_to_modnet_input, - modnet_output_to_luma_images, + models::flame::{ + FlameInput, + FlameOutput, + Flame, + FlamePlugin, }, - Onnx, }; @@ -46,95 +48,51 @@ fn main() { .add_plugins(( DefaultPlugins, BevyOrtPlugin, + FlamePlugin, )) - .init_resource::<Modnet>() - .add_systems(Startup, load_modnet) - .add_systems(Update, inference) + .add_systems(Startup, load_flame) + .add_systems(Startup, setup) + .add_systems(Update, on_flame_output) .run(); } -#[derive(Resource, Default)] -pub struct Modnet { - pub onnx: Handle<Onnx>, - pub input: Handle<Image>, -} -fn load_modnet( +fn load_flame( asset_server: Res<AssetServer>, - mut modnet: ResMut<Modnet>, + mut flame: ResMut<Flame>, ) { - let modnet_handle: Handle<Onnx> = asset_server.load("modnet_photographic_portrait_matting.onnx"); - modnet.onnx = modnet_handle; + flame.onnx = asset_server.load("models/flame.onnx"); +} + - let input_handle: Handle<Image> = asset_server.load("person.png"); - modnet.input = input_handle; +fn setup( + mut commands: Commands, +) { + commands.spawn(FlameInput::default()); + commands.spawn(Camera3dBundle::default()); } -fn inference( +#[derive(Debug, Component, Reflect)] +struct HandledFlameOutput; + +fn on_flame_output( mut commands: Commands, - modnet: Res<Modnet>, - onnx_assets: Res<Assets<Onnx>>, - mut images: ResMut<Assets<Image>>, - mut complete: Local<bool>, + flame_outputs: Query< + ( + Entity, + &FlameOutput, + ), + Without<HandledFlameOutput>, + >, ) { - if *complete { - return; - } + for (entity, flame_output) in flame_outputs.iter() { + commands.entity(entity) + .insert(HandledFlameOutput); - let image = images.get(&modnet.input).expect("failed to get image asset"); - let input = images_to_modnet_input(vec![&image]); - - let mask_image: Result<Image, String> = (|| {
- let onnx = onnx_assets.get(&modnet.onnx).ok_or("failed to get ONNX asset")?; - let session_lock = onnx.session.lock().map_err(|e| e.to_string())?; - let session = session_lock.as_ref().ok_or("failed to get session from ONNX asset")?; - - let input_values = inputs!["input" => input.view()].map_err(|e| e.to_string())?; - let outputs = session.run(input_values).map_err(|e| e.to_string()); - - let binding = outputs.ok().unwrap(); - let output_value: &ort::Value = binding.get("output").unwrap(); - - Ok(modnet_output_to_luma_images(output_value).pop().unwrap()) - })(); - - match mask_image { - Ok(mask_image) => { - let mask_image = images.add(mask_image); - - commands.spawn(NodeBundle { - style: Style { - display: Display::Grid, - width: Val::Percent(100.0), - height: Val::Percent(100.0), - grid_template_columns: RepeatedGridTrack::flex(1, 1.0), - grid_template_rows: RepeatedGridTrack::flex(1, 1.0), - ..default() - }, - background_color: BackgroundColor(Color::DARK_GRAY), - ..default() - }) - .with_children(|builder| { - builder.spawn(ImageBundle { - style: Style { - ..default() - }, - image: UiImage::new(mask_image.clone()), - ..default() - }); - }); - - commands.spawn(Camera2dBundle::default()); - - *complete = true; - } - Err(e) => { - println!("inference failed: {}", e); - } + println!("{:?}", flame_output); } } - ``` @@ -162,3 +120,14 @@ use an accelerated execution provider: ## credits - [modnet](https://github.com/ZHKKKe/MODNet) + + +## license + +This software is dual-licensed under the MIT License and the GNU General Public License version 3 (GPL-3.0). + +You may choose to use this software under the terms of the MIT License OR the GNU General Public License version 3 (GPL-3.0), except as stipulated below: + +The use of the `yolo_v8` feature within this software is specifically governed by the GNU General Public License version 3 (GPL-3.0). By using the `yolo_v8` feature, you agree to comply with the terms and conditions of the GPL-3.0. 
+ +For more details on the licenses, please refer to the LICENSE-MIT and LICENSE.GPL-3.0 files included with this software. diff --git a/assets/mask.png b/assets/images/mask.png similarity index 100% rename from assets/mask.png rename to assets/images/mask.png diff --git a/assets/person.png b/assets/images/person.png similarity index 100% rename from assets/person.png rename to assets/images/person.png diff --git a/assets/images/sacre_coeur1.png b/assets/images/sacre_coeur1.png new file mode 100644 index 0000000..c8bb07f Binary files /dev/null and b/assets/images/sacre_coeur1.png differ diff --git a/assets/images/sacre_coeur2.png b/assets/images/sacre_coeur2.png new file mode 100644 index 0000000..2cd68f7 Binary files /dev/null and b/assets/images/sacre_coeur2.png differ diff --git a/assets/models/disk_lightglue_end2end_fused.onnx b/assets/models/disk_lightglue_end2end_fused.onnx new file mode 100644 index 0000000..b97fd6e Binary files /dev/null and b/assets/models/disk_lightglue_end2end_fused.onnx differ diff --git a/assets/models/disk_lightglue_end2end_fused_cpu.onnx b/assets/models/disk_lightglue_end2end_fused_cpu.onnx new file mode 100644 index 0000000..6b988a3 Binary files /dev/null and b/assets/models/disk_lightglue_end2end_fused_cpu.onnx differ diff --git a/assets/modnet_photographic_portrait_matting.onnx b/assets/models/modnet_photographic_portrait_matting.onnx similarity index 100% rename from assets/modnet_photographic_portrait_matting.onnx rename to assets/models/modnet_photographic_portrait_matting.onnx diff --git a/assets/models/yolov8n.onnx b/assets/models/yolov8n.onnx new file mode 100644 index 0000000..cfb621b Binary files /dev/null and b/assets/models/yolov8n.onnx differ diff --git a/benches/modnet.rs b/benches/modnet.rs new file mode 100644 index 0000000..60e13df --- /dev/null +++ b/benches/modnet.rs @@ -0,0 +1,151 @@ +use criterion::{ + BenchmarkId, + criterion_group, + criterion_main, + Criterion, + Throughput, +}; + +use bevy::{ + prelude::*, +
render::{ + render_asset::RenderAssetUsages, + render_resource::{ + Extent3d, + TextureDimension, + }, + }, +}; +use bevy_ort::{ + inputs, + models::modnet::{ + modnet_inference, + modnet_output_to_luma_images, + images_to_modnet_input, + }, + OrtSession, + Session, +}; +use ort::GraphOptimizationLevel; + + +const MAX_RESOLUTIONS: [(u32, u32); 4] = [ + (256, 256), + (512, 512), + (1024, 1024), + (2048, 2048), +]; + +const STREAM_COUNT: usize = 16; + + +criterion_group!{ + name = modnet_benches; + config = Criterion::default().sample_size(10); + targets = images_to_modnet_input_benchmark, + modnet_output_to_luma_images_benchmark, + modnet_inference_benchmark, +} +criterion_main!(modnet_benches); + + +fn images_to_modnet_input_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("images_to_modnet_input"); + + MAX_RESOLUTIONS.iter() + .for_each(|(width, height)| { + let data = vec![0u8; (1920 * 1080 * 4) as usize]; + + let images = (0..STREAM_COUNT) + .map(|_|{ + Image::new( + Extent3d { + width: 1920, + height: 1080, + depth_or_array_layers: 1, + }, + TextureDimension::D2, + data.clone(), + bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb, + RenderAssetUsages::all(), + ) + }) + .collect::>(); + + group.throughput(Throughput::Elements(STREAM_COUNT as u64)); + group.bench_with_input(BenchmarkId::from_parameter(format!("{}x{}", width, height)), &images, |b, images| { + let views = images.iter().map(|image| image).collect::>(); + + b.iter(|| images_to_modnet_input(views.as_slice(), Some((*width, *height)))); + }); + }); +} + + +fn modnet_output_to_luma_images_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("modnet_output_to_luma_images"); + + let session = Session::builder().unwrap() + .with_optimization_level(GraphOptimizationLevel::Level3).unwrap() + .commit_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap(); + + let data = vec![0u8; (1920 * 1080 * 4) as usize]; + let image: Image = Image::new( + Extent3d 
{ + width: 1920, + height: 1080, + depth_or_array_layers: 1, + }, + TextureDimension::D2, + data.clone(), + bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb, + RenderAssetUsages::all(), + ); + + MAX_RESOLUTIONS.iter() + .for_each(|size_limit| { + let input = images_to_modnet_input(&[&image; STREAM_COUNT], size_limit.clone().into()); + let input_values = inputs!["input" => input.view()].map_err(|e| e.to_string()).unwrap(); + + let outputs = session.run(input_values).map_err(|e| e.to_string()); + let binding = outputs.ok().unwrap(); + let output_value: &ort::Value = binding.get("output").unwrap(); + + group.throughput(Throughput::Elements(STREAM_COUNT as u64)); + group.bench_with_input(BenchmarkId::from_parameter(format!("{}x{}", size_limit.0, size_limit.1)), &output_value, |b, output_value| { + b.iter(|| modnet_output_to_luma_images(output_value)); + }); + }); +} + + +fn modnet_inference_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("modnet_inference"); + + let session = Session::builder().unwrap() + .with_optimization_level(GraphOptimizationLevel::Level3).unwrap() + .commit_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap(); + let session: bevy_ort::OrtSession = OrtSession::Session(session); + + MAX_RESOLUTIONS.iter().for_each(|(width, height)| { + let data = vec![0u8; *width as usize * *height as usize * 4]; + let image = Image::new( + Extent3d { + width: *width, + height: *height, + depth_or_array_layers: 1, + }, + TextureDimension::D2, + data.clone(), + bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb, + RenderAssetUsages::all(), + ); + + group.throughput(Throughput::Elements(1)); + group.bench_with_input(BenchmarkId::from_parameter(format!("{}x{}", width, height)), &(width, height), |b, _| { + b.iter(|| { + modnet_inference(&session, &[&image], Some((*width, *height))) + }); + }); + }); +} diff --git a/benches/yolo_v8.rs b/benches/yolo_v8.rs new file mode 100644 index 0000000..dbc29f0 --- /dev/null +++
b/benches/yolo_v8.rs @@ -0,0 +1,145 @@ +use criterion::{ + BenchmarkId, + criterion_group, + criterion_main, + Criterion, + Throughput, +}; + +use bevy::{ + prelude::*, + render::{ + render_asset::RenderAssetUsages, + render_resource::{ + Extent3d, + TextureDimension, + }, + }, +}; +use bevy_ort::{ + inputs, + models::yolo_v8::{ + prepare_input, + process_output, + yolo_inference, + }, + OrtSession, + Session, +}; +use ort::GraphOptimizationLevel; + + +criterion_group!{ + name = yolo_v8_benches; + config = Criterion::default().sample_size(10); + targets = prepare_input_benchmark, + process_output_benchmark, + inference_benchmark, +} +criterion_main!(yolo_v8_benches); + + +const RESOLUTIONS: [(u32, u32); 3] = [ + (640, 360), + (1280, 720), + (1920, 1080), +]; + +// TODO: read input shape from session +const MODEL_WIDTH: u32 = 640; +const MODEL_HEIGHT: u32 = 640; + + +fn prepare_input_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("yolo_v8_prepare_input"); + + RESOLUTIONS.iter() + .for_each(|(width, height)| { + let data = vec![0u8; (width * height * 4) as usize]; + let image = Image::new( + Extent3d { + width: *width, + height: *height, + depth_or_array_layers: 1, + }, + TextureDimension::D2, + data.clone(), + bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb, + RenderAssetUsages::all(), + ); + + group.throughput(Throughput::Elements(1)); + group.bench_with_input(BenchmarkId::from_parameter(format!("{}x{}", width, height)), &image, |b, image| { + b.iter(|| prepare_input(&image, MODEL_WIDTH, MODEL_HEIGHT)); + }); + }); +} + + +fn process_output_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("yolo_v8_process_output"); + + let session = Session::builder().unwrap() + .with_optimization_level(GraphOptimizationLevel::Level3).unwrap() + .commit_from_file("assets/yolov8n.onnx").unwrap(); + + RESOLUTIONS.iter() + .for_each(|(width, height)| { + let data = vec![0u8; (width * height * 4) as usize]; + let image: Image = 
Image::new( + Extent3d { + width: *width, + height: *height, + depth_or_array_layers: 1, + }, + TextureDimension::D2, + data.clone(), + bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb, + RenderAssetUsages::all(), + ); + + let input = prepare_input(&image, MODEL_WIDTH, MODEL_HEIGHT); + let input_values = inputs!["images" => &input.as_standard_layout()].map_err(|e| e.to_string()).unwrap(); + + let outputs = session.run(input_values).map_err(|e| e.to_string()); + let binding = outputs.ok().unwrap(); + let output_value: &ort::Value = binding.get("output0").unwrap(); + + group.throughput(Throughput::Elements(1)); + group.bench_with_input(BenchmarkId::from_parameter(format!("{}x{}", width, height)), &output_value, |b, output_value| { + b.iter(|| process_output(output_value, *width, *height, MODEL_WIDTH, MODEL_HEIGHT)); + }); + }); +} + + +fn inference_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("yolo_v8_inference"); + + let session = Session::builder().unwrap() + .with_optimization_level(GraphOptimizationLevel::Level3).unwrap() + .commit_from_file("assets/yolov8n.onnx").unwrap(); + let session = OrtSession::Session(session); + + RESOLUTIONS.iter().for_each(|(width, height)| { + let data = vec![0u8; *width as usize * *height as usize * 4]; + let image = Image::new( + Extent3d { + width: *width, + height: *height, + depth_or_array_layers: 1, + }, + TextureDimension::D2, + data.clone(), + bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb, + RenderAssetUsages::all(), + ); + + group.throughput(Throughput::Elements(1)); + group.bench_with_input(BenchmarkId::from_parameter(format!("{}x{}", width, height)), &(width, height), |b, _| { + b.iter(|| { + yolo_inference(&session, &image, 0.5) + }); + }); + }); +} diff --git a/src/lib.rs b/src/lib.rs index 62038ed..09935c2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,9 +51,57 @@ impl Plugin for BevyOrtPlugin { } -#[derive(Asset, Debug, Default, TypePath)] +pub enum OrtSession { + 
Session(ort::Session), + InMemory(ort::InMemorySession<'static>), +} + +impl OrtSession { + pub fn run<'s, 'i, 'v: 'i, const N: usize>( + &'s self, + input_values: impl Into>, + ) -> Result { + match self { + OrtSession::Session(session) => session.run(input_values), + OrtSession::InMemory(session) => session.run(input_values), + } + } + + pub fn inputs(&self) -> &Vec { + match self { + OrtSession::Session(session) => &session.inputs, + OrtSession::InMemory(session) => &session.inputs, + } + } + + pub fn outputs(&self) -> &Vec { + match self { + OrtSession::Session(session) => &session.outputs, + OrtSession::InMemory(session) => &session.outputs, + } + } +} + +#[derive(Asset, Default, TypePath)] pub struct Onnx { - pub session: Arc>>, + pub session_data: Vec, + pub session: Arc>>, +} + +impl Onnx { + pub fn from_session(session: Session) -> Self { + Self { + session_data: Vec::new(), + session: Arc::new(Mutex::new(Some(OrtSession::Session(session)))), + } + } + + pub fn from_in_memory(session: ort::InMemorySession<'static>) -> Self { + Self { + session_data: Vec::new(), + session: Arc::new(Mutex::new(Some(OrtSession::InMemory(session)))), + } + } } @@ -88,11 +136,9 @@ impl AssetLoader for OnnxLoader { // TODO: add session configuration let session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level3)? 
- .with_model_from_memory(&bytes)?; + .commit_from_memory(&bytes)?; - Ok(Onnx { - session: Arc::new(Mutex::new(Some(session))), - }) + Ok(Onnx::from_session(session)) }, _ => Err(BevyOrtError::Io(std::io::Error::new(ErrorKind::Other, "only .onnx supported"))), } diff --git a/src/models/flame.rs b/src/models/flame.rs new file mode 100644 index 0000000..f366552 --- /dev/null +++ b/src/models/flame.rs @@ -0,0 +1,237 @@ +use bevy::{ + prelude::*, + render::{ + mesh::{ + Indices, + Mesh, + Meshable, + PrimitiveTopology, + }, + render_asset::RenderAssetUsages, + }, +}; +use bytemuck::cast_slice; +use include_bytes_aligned::include_bytes_aligned; +use ndarray::Array2; +use serde::{Deserialize, Serialize}; + +use crate::{ + inputs, + Onnx, + OrtSession, +}; + + +pub static INDEX_BUFFER: &[u8] = include_bytes_aligned!(4, "flame_index_buffer.bin"); + + +pub struct FlamePlugin; +impl Plugin for FlamePlugin { + fn build(&self, app: &mut App) { + app.init_resource::(); + app.add_systems(PreUpdate, flame_inference_system); + } +} + + + +#[derive(Resource, Default)] +pub struct Flame { + pub onnx: Handle, +} + + +fn flame_inference_system( + mut commands: Commands, + flame: Res, + onnx_assets: Res>, + flame_inputs: Query< + ( + Entity, + &FlameInput, + ), + Without, + >, +) { + for (entity, flame_input) in flame_inputs.iter() { + let flame_output: Result = (|| { + let onnx = onnx_assets.get(&flame.onnx).ok_or("failed to get flame ONNX asset")?; + let session_lock = onnx.session.lock().map_err(|e| e.to_string())?; + let session = session_lock.as_ref().ok_or("failed to get flame session from flame ONNX asset")?; + + Ok(flame_inference( + session, + flame_input, + )) + })(); + + match flame_output { + Ok(flame_output) => { + commands.entity(entity) + .insert(flame_output); + } + Err(_e) => { + return; + } + } + } +} + + +const FLAME_BATCH_SIZE: usize = 1; + +#[derive( + Debug, + Clone, + Component, + Reflect, +)] +pub struct FlameInput { + pub shape: [[f32; 100]; FLAME_BATCH_SIZE], 
+ pub pose: [[f32; 6]; FLAME_BATCH_SIZE], + pub expression: [[f32; 50]; FLAME_BATCH_SIZE], + pub neck: [[f32; 3]; FLAME_BATCH_SIZE], + pub eye: [[f32; 6]; FLAME_BATCH_SIZE], +} + +impl Default for FlameInput { + fn default() -> Self { + Self { + shape: [[0.0; 100]; FLAME_BATCH_SIZE], + pose: [[0.0; 6]; FLAME_BATCH_SIZE], + expression: [[0.0; 50]; FLAME_BATCH_SIZE], + neck: [[0.0; 3]; FLAME_BATCH_SIZE], + eye: [[0.0; 6]; FLAME_BATCH_SIZE], + } + } +} + + +#[derive( + Debug, + Clone, + Component, + Deserialize, + Serialize, + Reflect, +)] +pub struct FlameOutput { + pub vertices: Vec<[f32; 3]>, // TODO: use Vec3 for binding + // pub landmarks: Vec<[f32; 3]>, +} + +impl Default for FlameOutput { + fn default() -> Self { + Self { + vertices: vec![[0.0; 3]; 5023], + // landmarks: vec![[0.0; 3]; 68], + } + } +} + +impl Meshable for FlameOutput { + type Output = Mesh; + + fn mesh(&self) -> Self::Output { + let indices = Indices::U32(cast_slice(INDEX_BUFFER).to_vec()); + + Mesh::new( + PrimitiveTopology::TriangleList, + RenderAssetUsages::default(), + ) + .with_inserted_attribute(Mesh::ATTRIBUTE_POSITION, self.vertices.clone()) + .with_inserted_indices(indices) + } +} + + +pub fn flame_inference( + session: &OrtSession, + input: &FlameInput, +) -> FlameOutput { + let PreparedInput { + shape, + expression, + pose, + neck, + eye, + } = prepare_input(input); + + let input_values = inputs![ + "shape" => shape.view(), + "expression" => expression.view(), + "pose" => pose.view(), + "neck" => neck.view(), + "eye" => eye.view(), + ].map_err(|e| e.to_string()).unwrap(); + let outputs = session.run(input_values).map_err(|e| e.to_string()); + let binding = outputs.ok().unwrap(); + + let vertices: &ort::Value = binding.get("vertices").unwrap(); + // let landmarks: &ort::Value = binding.get("landmarks").unwrap(); + + post_process( + vertices, + // landmarks, + ) +} + + +pub struct PreparedInput { + pub shape: Array2, + pub pose: Array2, + pub expression: Array2, + pub neck: Array2, + 
pub eye: Array2, +} + +pub fn prepare_input( + input: &FlameInput, +) -> PreparedInput { + let shape = Array2::from_shape_vec((FLAME_BATCH_SIZE, 100), input.shape.concat()).unwrap(); + let pose = Array2::from_shape_vec((FLAME_BATCH_SIZE, 6), input.pose.concat()).unwrap(); + let expression = Array2::from_shape_vec((FLAME_BATCH_SIZE, 50), input.expression.concat()).unwrap(); + let neck = Array2::from_shape_vec((FLAME_BATCH_SIZE, 3), input.neck.concat()).unwrap(); + let eye = Array2::from_shape_vec((FLAME_BATCH_SIZE, 6), input.eye.concat()).unwrap(); + + PreparedInput { + shape, + expression, + pose, + neck, + eye, + } +} + + +pub fn post_process( + vertices: &ort::Value, + // landmarks: &ort::Value, +) -> FlameOutput { + let vertices_tensor = vertices.try_extract_tensor::().unwrap(); + let vertices_view = vertices_tensor.view(); // [FLAME_BATCH_SIZE, 5023, 3] + + // let landmarks_tensor = landmarks.try_extract_tensor::().unwrap(); + // let landmarks_view = landmarks_tensor.view(); // [FLAME_BATCH_SIZE, 68, 3] + + let vertices = vertices_view.outer_iter() + .flat_map(|subtensor| { + subtensor.outer_iter().map(|row| { + [row[0], row[1], row[2]] + }).collect::>() + }) + .collect::>(); + + // let landmarks = landmarks_view.outer_iter() + // .flat_map(|subtensor| { + // subtensor.outer_iter().map(|row| { + // [row[0], row[1], row[2]] + // }).collect::>() + // }) + // .collect::>(); + + FlameOutput { + vertices, + // landmarks, + } +} diff --git a/src/models/flame_index_buffer.bin b/src/models/flame_index_buffer.bin new file mode 100644 index 0000000..fef98f5 Binary files /dev/null and b/src/models/flame_index_buffer.bin differ diff --git a/src/models/lightglue.rs b/src/models/lightglue.rs new file mode 100644 index 0000000..6180455 --- /dev/null +++ b/src/models/lightglue.rs @@ -0,0 +1,137 @@ +use bevy::prelude::*; +use image::GenericImageView; +use ndarray::{Array, ArrayD, Axis}; +use serde::{Deserialize, Serialize}; + +use crate::{ + inputs, + Onnx, + OrtSession, +}; + 
+ + +pub struct LightgluePlugin; +impl Plugin for LightgluePlugin { + fn build(&self, app: &mut App) { + app.init_resource::(); + } +} + +#[derive(Resource, Default)] +pub struct Lightglue { + pub onnx: Handle, +} + + +#[derive(Debug, Default, Clone, Deserialize, Serialize)] +pub struct GluedPair { + pub from_x: i64, + pub from_y: i64, + pub to_x: i64, + pub to_y: i64, +} + + +pub fn lightglue_inference( + session: &OrtSession, + images: &[&Image], +) -> Vec<(usize, usize, Vec)> { + let unique_unordered_pairs = images.iter().enumerate() + .flat_map(|(i, _)| { + images.iter().enumerate().skip(i + 1).map(move |(j, _)| (i, j)) + }) + .collect::>(); + + unique_unordered_pairs.iter() + .map(|(i, j)| { + let a = images[*i]; + let b = images[*j]; + + let prepared_a = prepare_input(a); + let prepared_b = prepare_input(b); + + let input_values = inputs![ + "image0" => prepared_a.view(), + "image1" => prepared_b.view(), + ].map_err(|e| e.to_string()).unwrap(); + let outputs = session.run(input_values).map_err(|e| e.to_string()); + let binding = outputs.ok().unwrap(); + + let kpts0: &ort::Value = binding.get("kpts0").unwrap(); + let kpts1: &ort::Value = binding.get("kpts1").unwrap(); + let matches0: &ort::Value = binding.get("matches0").unwrap(); + + ( + *i, + *j, + post_process( + kpts0, + kpts1, + matches0, + ).unwrap(), + ) + }) + .collect::>() +} + + +pub fn prepare_input( + image: &Image, +) -> ArrayD { + let image = &image.clone().try_into_dynamic().unwrap(); + + let mut input = Array::zeros((1, 3, image.height() as usize, image.width() as usize)).into_dyn(); + + image.pixels().for_each(|(x, y, pixel)| { + let [r, g, b, _] = pixel.0; + let (x, y) = (x as usize, y as usize); + + input[[0, 0, y, x]] = r as f32 / 255.0; + input[[0, 1, y, x]] = g as f32 / 255.0; + input[[0, 2, y, x]] = b as f32 / 255.0; + }); + + input +} + + +pub fn post_process( + kpts0: &ort::Value, + kpts1: &ort::Value, + matches: &ort::Value, +) -> Result, &'static str> { + let kpts0_tensor = 
kpts0.try_extract_tensor::().unwrap(); + let kpts0_view = kpts0_tensor.view(); + + let kpts1_tensor = kpts1.try_extract_tensor::().unwrap(); + let kpts1_view = kpts1_tensor.view(); + + let matches = matches.try_extract_tensor::().unwrap(); + let matches_view = matches.view(); + + Ok( + matches_view.axis_iter(Axis(0)) + .map(|row| { + let kpts0_idx = row[0]; + let kpts1_idx = row[1]; + + let kpt0 = kpts0_view.index_axis(Axis(1), kpts0_idx as usize); + + let kpt0_x = kpt0[[0, 0]]; + let kpt0_y = kpt0[[0, 1]]; + + let kpt1 = kpts1_view.index_axis(Axis(1), kpts1_idx as usize); + let kpt1_x = kpt1[[0, 0]]; + let kpt1_y = kpt1[[0, 1]]; + + GluedPair { + from_x: kpt0_x, + from_y: kpt0_y, + to_x: kpt1_x, + to_y: kpt1_y, + } + }) + .collect::>() + ) +} diff --git a/src/models/mod.rs b/src/models/mod.rs index 59367d2..6c07ed6 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,2 +1,11 @@ +#[cfg(feature = "flame")] +pub mod flame; + +#[cfg(feature = "lightglue")] +pub mod lightglue; + #[cfg(feature = "modnet")] pub mod modnet; + +#[cfg(feature = "yolo_v8")] +pub mod yolo_v8; diff --git a/src/models/modnet.rs b/src/models/modnet.rs index 7721100..db4a0ec 100644 --- a/src/models/modnet.rs +++ b/src/models/modnet.rs @@ -1,13 +1,59 @@ -use bevy::{prelude::*, render::render_asset::RenderAssetUsages}; +use bevy::{ + prelude::*, + render::{ + render_asset::RenderAssetUsages, + render_resource::{ + Extent3d, + TextureDimension, + TextureFormat, + }, + }, +}; use image::{DynamicImage, GenericImageView, imageops::FilterType, ImageBuffer, Luma, RgbImage}; -use ndarray::{Array, Array4, ArrayView4, Axis, s}; +use ndarray::{Array, Array4, ArrayView4}; +use rayon::prelude::*; + +use crate::{ + inputs, + Onnx, + OrtSession, +}; + + + +pub struct ModnetPlugin; +impl Plugin for ModnetPlugin { + fn build(&self, app: &mut App) { + app.init_resource::(); + } +} + +#[derive(Resource, Default)] +pub struct Modnet { + pub onnx: Handle, +} + + +pub fn modnet_inference( + session: 
&OrtSession, + images: &[&Image], + max_size: Option<(u32, u32)>, +) -> Vec { + let input = images_to_modnet_input(images, max_size); + + let input_values = inputs!["input" => input.view()].map_err(|e| e.to_string()).unwrap(); + let outputs = session.run(input_values).map_err(|e| e.to_string()); + let binding = outputs.ok().unwrap(); + let output_value: &ort::Value = binding.get("output").unwrap(); + + modnet_output_to_luma_images(output_value) +} pub fn modnet_output_to_luma_images( output_value: &ort::Value, ) -> Vec { - let tensor: ort::Tensor = output_value.extract_tensor::().unwrap(); - + let tensor = output_value.try_extract_tensor::().unwrap(); let data = tensor.view(); let shape = data.shape(); @@ -18,30 +64,37 @@ pub fn modnet_output_to_luma_images( let tensor_data = ArrayView4::from_shape((batch_size, 1, height, width), data.as_slice().unwrap()) .expect("failed to create ArrayView4 from shape and data"); - let mut images = Vec::new(); - - for i in 0..batch_size { - let mut imgbuf = ImageBuffer::, Vec>::new(width as u32, height as u32); - - for y in 0..height { - for x in 0..width { - let pixel_value = tensor_data[(i, 0, y, x)]; - let pixel_value = (pixel_value.clamp(0.0, 1.0) * 255.0) as u8; - imgbuf.put_pixel(x as u32, y as u32, Luma([pixel_value])); + (0..batch_size) + .into_par_iter() + .map(|i| { + let mut imgbuf = ImageBuffer::, Vec>::new(width as u32, height as u32); + + for y in 0..height { + for x in 0..width { + let pixel_value = tensor_data[(i, 0, y, x)]; + let pixel_value = (pixel_value.clamp(0.0, 1.0) * 255.0) as u8; + imgbuf.put_pixel(x as u32, y as u32, Luma([pixel_value])); + } } - } - - let dyn_img = DynamicImage::ImageLuma8(imgbuf); - - images.push(Image::from_dynamic(dyn_img, false, RenderAssetUsages::all())); - } - images + Image::new( + Extent3d { + width: width as u32, + height: height as u32, + depth_or_array_layers: 1, + }, + TextureDimension::D2, + imgbuf.into_raw(), + TextureFormat::R8Unorm, + RenderAssetUsages::all(), + ) + }) + 
.collect::>() } pub fn images_to_modnet_input( - images: Vec<&Image>, + images: &[&Image], max_size: Option<(u32, u32)>, ) -> Array4 { if images.is_empty() { @@ -51,58 +104,45 @@ pub fn images_to_modnet_input( let ref_size = 512; let &first_image = images.first().unwrap(); - let image = first_image.to_owned(); - - let (x_scale, y_scale) = get_scale_factor(image.height(), image.width(), ref_size, max_size); - let resized_image = resize_image(&image.try_into_dynamic().unwrap(), x_scale, y_scale); - let first_image_ndarray = image_to_ndarray(&resized_image); - - let single_image_shape = first_image_ndarray.dim(); - let n_images = images.len(); - let batch_shape = (n_images, single_image_shape.1, single_image_shape.2, single_image_shape.3); + let (x_scale, y_scale) = get_scale_factor(first_image.height(), first_image.width(), ref_size, max_size); - let mut aggregate = Array4::::zeros(batch_shape); - - for (i, &image) in images.iter().enumerate() { - let image = image.to_owned(); - let (x_scale, y_scale) = get_scale_factor(image.height(), image.width(), ref_size, max_size); - let resized_image = resize_image(&image.try_into_dynamic().unwrap(), x_scale, y_scale); - let image_ndarray = image_to_ndarray(&resized_image); + let processed_images: Vec> = images + .par_iter() + .map(|&image| { + let resized_image = resize_image(&image.clone().try_into_dynamic().unwrap(), x_scale, y_scale); + image_to_ndarray(&resized_image) + }) + .collect(); - let slice = s![i, .., .., ..]; - aggregate.slice_mut(slice).assign(&image_ndarray.index_axis_move(Axis(0), 0)); - } + let aggregate = Array::from_shape_vec( + (processed_images.len(), processed_images[0].shape()[1], processed_images[0].shape()[2], processed_images[0].shape()[3]), + processed_images.iter().flat_map(|a| a.iter().cloned()).collect(), + ).unwrap(); aggregate } fn get_scale_factor(im_h: u32, im_w: u32, ref_size: u32, max_size: Option<(u32, u32)>) -> (f32, f32) { - // Calculate the scale factor based on the maximum size 
constraints let scale_factor_max = max_size.map_or(1.0, |(max_w, max_h)| { f32::min(max_w as f32 / im_w as f32, max_h as f32 / im_h as f32) }); - // Calculate the target dimensions after applying the max scale factor (clipping to max_size) let (target_h, target_w) = ((im_h as f32 * scale_factor_max).round() as u32, (im_w as f32 * scale_factor_max).round() as u32); - // Calculate the scale factor to fit within the reference size, considering the target dimensions let (scale_factor_ref_w, scale_factor_ref_h) = if std::cmp::max(target_h, target_w) < ref_size { let scale_factor = ref_size as f32 / std::cmp::max(target_h, target_w) as f32; (scale_factor, scale_factor) } else { - (1.0, 1.0) // Do not upscale if target dimensions are within reference size + (1.0, 1.0) }; - // Calculate the final scale factor as the minimum of the max scale factor and the reference scale factor let final_scale_w = f32::min(scale_factor_max, scale_factor_ref_w); let final_scale_h = f32::min(scale_factor_max, scale_factor_ref_h); - // Adjust dimensions to ensure they are multiples of 32 let final_w = ((im_w as f32 * final_scale_w).round() as u32) - ((im_w as f32 * final_scale_w).round() as u32) % 32; let final_h = ((im_h as f32 * final_scale_h).round() as u32) - ((im_h as f32 * final_scale_h).round() as u32) % 32; - // Return the scale factors based on the original image dimensions (final_w as f32 / im_w as f32, final_h as f32 / im_h as f32) } @@ -110,21 +150,18 @@ fn get_scale_factor(im_h: u32, im_w: u32, ref_size: u32, max_size: Option<(u32, fn image_to_ndarray(img: &RgbImage) -> Array4 { let (width, height) = img.dimensions(); - // convert RgbImage to a Vec of f32 values normalized to [-1, 1] - let raw: Vec = img.pixels() - .flat_map(|p| { - p.0.iter().map(|&e| { - (e as f32 - 127.5) / 127.5 - }) - }) - .collect(); - - // create a 3D array from the raw pixel data - let arr = Array::from_shape_vec((height as usize, width as usize, 3), raw) - .expect("failed to create ndarray from image raw 
data"); + let arr = Array::from_shape_fn((1, 3, height as usize, width as usize), |(_, c, y, x)| { + let pixel = img.get_pixel(x as u32, y as u32); + let channel_value = match c { + 0 => pixel[0], + 1 => pixel[1], + 2 => pixel[2], + _ => unreachable!(), + }; + (channel_value as f32 - 127.5) / 127.5 + }); - // rearrange the dimensions from [height, width, channels] to [1, channels, height, width] - arr.permuted_axes([2, 0, 1]).insert_axis(Axis(0)) + arr } fn resize_image(image: &DynamicImage, x_scale: f32, y_scale: f32) -> RgbImage { diff --git a/src/models/yolo_v8.rs b/src/models/yolo_v8.rs new file mode 100644 index 0000000..e9c5164 --- /dev/null +++ b/src/models/yolo_v8.rs @@ -0,0 +1,398 @@ +use bevy::prelude::*; +use image::GenericImageView; +use ndarray::{Array, ArrayD, Axis}; +use serde::{Deserialize, Serialize}; + +use crate::{ + inputs, + Onnx, + OrtSession, +}; + + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct BoundingBox { + pub x1: f32, + pub y1: f32, + pub x2: f32, + pub y2: f32, + pub class_id: usize, + pub prob: f32, +} + + +pub struct YoloPlugin; +impl Plugin for YoloPlugin { + fn build(&self, app: &mut App) { + app.init_resource::(); + } +} + +#[derive(Resource, Default)] +pub struct Yolo { + pub onnx: Handle, +} + + +// TODO: support yolo input batching +pub fn yolo_inference( + session: &OrtSession, + image: &Image, + iou_threshold: f32, +) -> Vec { + let width = image.width(); + let height = image.height(); + + let model_width = session.inputs()[0].input_type.tensor_dimensions().unwrap()[2] as u32; + let model_height = session.inputs()[0].input_type.tensor_dimensions().unwrap()[3] as u32; + + let input = prepare_input(image, model_width, model_height); + + let input_values = inputs!["images" => &input.as_standard_layout()].map_err(|e| e.to_string()).unwrap(); + let outputs = session.run(input_values).map_err(|e| e.to_string()); + let binding = outputs.ok().unwrap(); + let output_value: &ort::Value = 
binding.get("output0").unwrap(); + + let detections = process_output(output_value, width, height, model_width, model_height); + + nms(&detections, iou_threshold) +} + + +pub fn prepare_input( + image: &Image, + model_width: u32, + model_height: u32, +) -> ArrayD { + let image = &image.clone().try_into_dynamic().unwrap(); + let image = image.resize_exact(model_width, model_height, image::imageops::FilterType::CatmullRom); + + let mut input = Array::zeros((1, 3, model_width as usize, model_height as usize)).into_dyn(); + + image.pixels().for_each(|(x, y, pixel)| { + let [r, g, b, _] = pixel.0; + let (x, y) = (x as usize, y as usize); + + input[[0, 0, y, x]] = r as f32 / 255.0; + input[[0, 1, y, x]] = g as f32 / 255.0; + input[[0, 2, y, x]] = b as f32 / 255.0; + }); + + input +} + + +pub fn process_output( + output: &ort::Value, + width: u32, + height: u32, + model_width: u32, + model_height: u32, +) -> Vec { + let mut boxes = Vec::new(); + + let tensor = output.try_extract_tensor::().unwrap(); + let data = tensor.view().t().into_owned(); + + for detection in data.axis_iter(Axis(0)) { + let detection : Vec<_> = detection.iter().collect(); + + let (class_id, prob) = detection.iter() + .skip(4) + .enumerate() + .reduce(|acc, row| if row.1 > acc.1 { row } else { acc }) + .unwrap(); + + if **prob < 0.5 { + continue; + } + + let xc = detection[0] / model_width as f32 * width as f32; + let yc = detection[1] / model_height as f32 * height as f32; + let w = detection[2] / model_width as f32 * width as f32; + let h = detection[3] / model_height as f32 * height as f32; + + let x1 = (xc - w / 2.0).max(0.0); + let y1 = (yc - h / 2.0).max(0.0); + let x2 = (xc + w / 2.0).min(width as f32); + let y2 = (yc + h / 2.0).min(height as f32); + + boxes.push(BoundingBox { + x1, + y1, + x2, + y2, + class_id, + prob: **prob, + }); + } + + boxes +} + + +pub fn nms( + input: &[BoundingBox], + iou_threshold: f32, +) -> Vec { + let mut output: Vec = Vec::new(); + + let mut boxes_by_class = 
std::collections::HashMap::new(); + for bbox in input { + boxes_by_class.entry(bbox.class_id) + .or_insert_with(Vec::new) + .push(bbox.clone()); + } + + for (_class_id, mut boxes) in boxes_by_class { + boxes.sort_by(|a, b| b.prob.partial_cmp(&a.prob).unwrap_or(std::cmp::Ordering::Equal)); + + while !boxes.is_empty() { + let highest = boxes.remove(0); + output.push(highest.clone()); + + boxes.retain(|bbox| iou(&highest, bbox) < iou_threshold); + } + } + + output +} + +fn iou(box_a: &BoundingBox, box_b: &BoundingBox) -> f32 { + let intersection_x1 = box_a.x1.max(box_b.x1); + let intersection_y1 = box_a.y1.max(box_b.y1); + let intersection_x2 = box_a.x2.min(box_b.x2); + let intersection_y2 = box_a.y2.min(box_b.y2); + + let intersection_area = 0f32.max(intersection_x2 - intersection_x1) * + 0f32.max(intersection_y2 - intersection_y1); + + let box_a_area = (box_a.x2 - box_a.x1) * (box_a.y2 - box_a.y1); + let box_b_area = (box_b.x2 - box_b.x1) * (box_b.y2 - box_b.y1); + + let union_area = box_a_area + box_b_area - intersection_area; + + if union_area == 0.0 { + return 0.0; + } + + intersection_area / union_area +} + + +pub const YOLO_CLASSES: [&str; 80] = [ + "person", + "bicycle", + "car", + "motorcycle", + "airplane", + "bus", + "train", + "truck", + "boat", + "traffic light", + "fire hydrant", + "stop sign", + "parking meter", + "bench", + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + "backpack", + "umbrella", + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "couch", + "potted plant", + "bed", + "dining table", + "toilet", + "tv", + "laptop", + "mouse", + 
"remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "book", + "clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush", +]; + + + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_non_overlapping_boxes() { + let a = BoundingBox { + x1: 0.0, + y1: 0.0, + x2: 1.0, + y2: 1.0, + class_id: 0, + prob: 0.9, + }; + + let b = BoundingBox { + x1: 2.0, + y1: 2.0, + x2: 3.0, + y2: 3.0, + class_id: 0, + prob: 0.8, + }; + + let filtered_boxes = nms(&[a.clone(), b.clone()], 0.5); + assert_eq!(filtered_boxes.len(), 2, "both boxes should be retained as they do not overlap."); + + assert_eq!(iou(&a, &b), 0.0, "the boxes do not overlap, so the IoU should be 0."); + } + + #[test] + fn test_overlapping_boxes_same_class() { + let a = BoundingBox { + x1: 0.0, + y1: 0.0, + x2: 2.0, + y2: 2.0, + class_id: 0, + prob: 0.9, + }; + + let b = BoundingBox { + x1: 1.0, + y1: 1.0, + x2: 3.0, + y2: 3.0, + class_id: 0, + prob: 0.8, + }; + + let expected_iou = 1.0 / 7.0; + + let filtered_boxes = nms(&[a.clone(), b.clone()], expected_iou - 0.1); + assert_eq!(filtered_boxes.len(), 1, "only one box should be retained due to overlap."); + assert_eq!(filtered_boxes[0].prob, 0.9, "the box with the higher probability should be retained."); + + assert_eq!((iou(&a, &b) - expected_iou).abs() < 1e-6, true, "the IoU should be approximately 1/7."); + } + + #[test] + fn test_overlapping_boxes_different_classes() { + let a = BoundingBox { + x1: 0.0, + y1: 0.0, + x2: 2.0, + y2: 2.0, + class_id: 0, + prob: 0.9, + }; + + let b = BoundingBox { + x1: 1.0, + y1: 1.0, + x2: 3.0, + y2: 3.0, + class_id: 1, + prob: 0.8, + }; + + let filtered_boxes = nms(&[a, b], 0.5); + assert_eq!(filtered_boxes.len(), 2, "both boxes should be retained as they belong to different classes."); + } + + #[test] + fn test_iou_complete_overlap() { + let a = BoundingBox { + x1: 0.0, + y1: 0.0, + x2: 2.0, + y2: 2.0, + class_id: 0, + prob: 0.9, + }; + + let 
b = BoundingBox { + x1: 0.0, + y1: 0.0, + x2: 2.0, + y2: 2.0, + class_id: 0, + prob: 0.8, + }; + + let expected_iou = 1.0; + assert_eq!((iou(&a, &b) - expected_iou).abs() < 1e-6, true, "the IoU should be 1.0."); + } + + #[test] + fn test_iou_overlap_edge_case() { + let a = BoundingBox { + x1: 0.0, + y1: 0.0, + x2: 2.0, + y2: 2.0, + class_id: 0, + prob: 0.9, + }; + + let b = BoundingBox { + x1: 2.0, + y1: 2.0, + x2: 4.0, + y2: 4.0, + class_id: 0, + prob: 0.8, + }; + + let expected_iou = 0.0; + assert_eq!((iou(&a, &b) - expected_iou).abs() < 1e-6, true, "the IoU should be 0.0."); + } +} diff --git a/tools/flame.rs b/tools/flame.rs new file mode 100644 index 0000000..815521f --- /dev/null +++ b/tools/flame.rs @@ -0,0 +1,78 @@ +use bevy::prelude::*; +use bevy_panorbit_camera::{ + PanOrbitCamera, + PanOrbitCameraPlugin, +}; + +use bevy_ort::{ + BevyOrtPlugin, + models::flame::{ + FlameInput, + FlameOutput, + Flame, + FlamePlugin, + }, +}; + + +fn main() { + App::new() + .add_plugins(( + DefaultPlugins, + BevyOrtPlugin, + FlamePlugin, + PanOrbitCameraPlugin, + )) + .add_systems(Startup, load_flame) + .add_systems(Startup, setup) + .add_systems(Update, on_flame_output) + .run(); +} + + +fn load_flame( + asset_server: Res, + mut flame: ResMut, +) { + flame.onnx = asset_server.load("models/flame.onnx"); +} + + +fn setup( + mut commands: Commands, +) { + commands.spawn(FlameInput::default()); + commands.spawn(( + Camera3dBundle::default(), + PanOrbitCamera { + allow_upside_down: true, + ..default() + }, + )); +} + + +#[derive(Debug, Component, Reflect)] +struct HandledFlameOutput; + +fn on_flame_output( + mut commands: Commands, + mut meshes: ResMut>, + flame_outputs: Query< + ( + Entity, + &FlameOutput, + ), + Without, + >, +) { + for (entity, flame_output) in flame_outputs.iter() { + commands.entity(entity) + .insert(HandledFlameOutput); + + commands.spawn(PbrBundle { + mesh: meshes.add(flame_output.mesh()), + ..default() + }); + } +} diff --git a/tools/lightglue.rs 
b/tools/lightglue.rs new file mode 100644 index 0000000..d2208fa --- /dev/null +++ b/tools/lightglue.rs @@ -0,0 +1,199 @@ +use bevy::{ + prelude::*, + window::PrimaryWindow, +}; + +use bevy_ort::{ + BevyOrtPlugin, + models::lightglue::{ + GluedPair, + lightglue_inference, + Lightglue, + LightgluePlugin, + }, + Onnx, +}; + + +fn main() { + App::new() + .add_plugins(( + DefaultPlugins, + BevyOrtPlugin, + LightgluePlugin, + )) + .init_resource::() + .add_systems(Startup, load_lightglue) + .add_systems(Update, inference) + .run(); +} + + +#[derive(Resource, Default)] +pub struct LightglueInput { + pub a: Handle, + pub b: Handle, +} + + +fn load_lightglue( + asset_server: Res, + mut lightglue: ResMut, + mut input: ResMut, +) { + let lightglue_handle: Handle = asset_server.load("models/disk_lightglue_end2end_fused_cpu.onnx"); + lightglue.onnx = lightglue_handle; + + input.a = asset_server.load("images/sacre_coeur1.png"); + input.b = asset_server.load("images/sacre_coeur2.png"); +} + + +fn inference( + mut commands: Commands, + lightglue: Res, + input: Res, + onnx_assets: Res>, + images: Res>, + primary_window: Query<&Window, With>, + mut complete: Local, +) { + if *complete { + return; + } + + let window = primary_window.single(); + + let images = [ + images.get(&input.a).expect("failed to get image asset"), + images.get(&input.b).expect("failed to get image asset"), + ]; + + let glued_pairs: Result)>, String> = (|| { + let onnx = onnx_assets.get(&lightglue.onnx).ok_or("failed to get ONNX asset")?; + let session_lock = onnx.session.lock().map_err(|e| e.to_string())?; + let session = session_lock.as_ref().ok_or("failed to get session from ONNX asset")?; + + Ok(lightglue_inference( + session, + &images, + )) + })(); + + match glued_pairs { + Ok(glued_pairs) => { + println!("glued_pairs: {:?}", glued_pairs[0].2.len()); + + commands.spawn(NodeBundle { + style: Style { + display: Display::Grid, + width: Val::Percent(100.0), + height: Val::Percent(100.0), + 
grid_template_columns: RepeatedGridTrack::flex(2, 1.0), + grid_template_rows: RepeatedGridTrack::flex(2, 1.0), + ..default() + }, + background_color: BackgroundColor(Color::DARK_GRAY), + ..default() + }) + .with_children(|builder| { + builder.spawn(ImageBundle { + style: Style { + ..default() + }, + image: UiImage::new(input.a.clone()), + ..default() + }) + .with_children(|builder| { + let image_width = images[0].width() as f32; + let image_height = images[0].height() as f32; + + let display_width = window.width() / 2.0; + let display_height = window.height() / 2.0; + + let scale_x = display_width / image_width; + let scale_y = display_height / image_height; + + glued_pairs[0].2.iter().for_each(|pair| { + let scaled_x = pair.from_x as f32 * scale_x; + let scaled_y = pair.from_y as f32 * scale_y; + + builder.spawn(NodeBundle { + style: Style { + position_type: PositionType::Absolute, + left: Val::Px(scaled_x), + top: Val::Px(scaled_y), + width: Val::Px(2.0), + height: Val::Px(2.0), + ..default() + }, + background_color: Color::rgb(1.0, 0.0, 0.0).into(), + ..default() + }); + }); + }); + + builder.spawn(ImageBundle { + style: Style { + ..default() + }, + image: UiImage::new(input.b.clone()), + ..default() + }) + .with_children(|builder| { + let image_width = images[1].width() as f32; + let image_height = images[1].height() as f32; + + let display_width = window.width() / 2.0; + let display_height = window.height() / 2.0; + + let scale_x = display_width / image_width; + let scale_y = display_height / image_height; + + glued_pairs[0].2.iter().for_each(|pair| { + let scaled_x = pair.to_x as f32 * scale_x; + let scaled_y = pair.to_y as f32 * scale_y; + + builder.spawn(NodeBundle { + style: Style { + position_type: PositionType::Absolute, + left: Val::Px(scaled_x), + top: Val::Px(scaled_y), + width: Val::Px(2.0), + height: Val::Px(2.0), + ..default() + }, + background_color: Color::rgb(0.0, 1.0, 0.0).into(), + ..default() + }); + }); + }); + + builder.spawn(ImageBundle { 
+ style: Style { + ..default() + }, + image: UiImage::new(input.a.clone()), + ..default() + }); + + builder.spawn(ImageBundle { + style: Style { + ..default() + }, + image: UiImage::new(input.b.clone()), + ..default() + }); + + // TODO: draw lines between keypoints + }); + + commands.spawn(Camera2dBundle::default()); + + *complete = true; + } + Err(e) => { + eprintln!("inference failed: {}", e); + } + } +} diff --git a/tools/modnet.rs b/tools/modnet.rs index 43a6957..fbda1a5 100644 --- a/tools/modnet.rs +++ b/tools/modnet.rs @@ -2,10 +2,10 @@ use bevy::prelude::*; use bevy_ort::{ BevyOrtPlugin, - inputs, models::modnet::{ - images_to_modnet_input, - modnet_output_to_luma_images, + modnet_inference, + Modnet, + ModnetPlugin, }, Onnx, }; @@ -17,36 +17,37 @@ fn main() { .add_plugins(( DefaultPlugins, BevyOrtPlugin, + ModnetPlugin, )) - .init_resource::() + .init_resource::() .add_systems(Startup, load_modnet) .add_systems(Update, inference) .run(); } - #[derive(Resource, Default)] -pub struct Modnet { - pub onnx: Handle, - pub input: Handle, +struct ModnetInput { + image: Handle, } fn load_modnet( asset_server: Res, mut modnet: ResMut, + mut input: ResMut, ) { - let modnet_handle: Handle = asset_server.load("modnet_photographic_portrait_matting.onnx"); + let modnet_handle: Handle = asset_server.load("models/modnet_photographic_portrait_matting.onnx"); modnet.onnx = modnet_handle; - let input_handle: Handle = asset_server.load("person.png"); - modnet.input = input_handle; + let input_handle: Handle = asset_server.load("images/person.png"); + input.image = input_handle; } fn inference( mut commands: Commands, modnet: Res, + input: Res, onnx_assets: Res>, mut images: ResMut>, mut complete: Local, @@ -55,21 +56,14 @@ fn inference( return; } - let image = images.get(&modnet.input).expect("failed to get image asset"); - let input = images_to_modnet_input(vec![&image], Some((256, 144))); + let image = images.get(&input.image).expect("failed to get image asset"); let 
mask_image: Result = (|| { let onnx = onnx_assets.get(&modnet.onnx).ok_or("failed to get ONNX asset")?; let session_lock = onnx.session.lock().map_err(|e| e.to_string())?; let session = session_lock.as_ref().ok_or("failed to get session from ONNX asset")?; - let input_values = inputs!["input" => input.view()].map_err(|e| e.to_string())?; - let outputs = session.run(input_values).map_err(|e| e.to_string()); - - let binding = outputs.ok().unwrap(); - let output_value: &ort::Value = binding.get("output").unwrap(); - - Ok(modnet_output_to_luma_images(output_value).pop().unwrap()) + Ok(modnet_inference(session, &[image], None).pop().unwrap()) })(); match mask_image { @@ -103,7 +97,7 @@ fn inference( *complete = true; } Err(e) => { - println!("inference failed: {}", e); + eprintln!("inference failed: {}", e); } } } diff --git a/tools/yolo_v8.rs b/tools/yolo_v8.rs new file mode 100644 index 0000000..bd5d92f --- /dev/null +++ b/tools/yolo_v8.rs @@ -0,0 +1,135 @@ +use bevy::{ + prelude::*, + window::PrimaryWindow, +}; + +use bevy_ort::{ + BevyOrtPlugin, + models::yolo_v8::{ + yolo_inference, + BoundingBox, + Yolo, + YoloPlugin, + }, + Onnx, +}; + + +fn main() { + App::new() + .add_plugins(( + DefaultPlugins, + BevyOrtPlugin, + YoloPlugin, + )) + .init_resource::() + .add_systems(Startup, load_yolo) + .add_systems(Update, inference) + .run(); +} + + +#[derive(Resource, Default)] +pub struct YoloInput { + pub image: Handle, +} + + +fn load_yolo( + asset_server: Res, + mut yolo: ResMut, + mut input: ResMut, +) { + let yolo_v8_handle: Handle = asset_server.load("models/yolov8n.onnx"); + yolo.onnx = yolo_v8_handle; + + let input_handle: Handle = asset_server.load("images/person.png"); + input.image = input_handle; +} + + +fn inference( + mut commands: Commands, + yolo: Res, + input: Res, + onnx_assets: Res>, + images: Res>, + primary_window: Query<&Window, With>, + mut complete: Local, +) { + if *complete { + return; + } + + let window = primary_window.single(); + + let image = 
images.get(&input.image).expect("failed to get image asset"); + + let bounding_boxes: Result, String> = (|| { + let onnx = onnx_assets.get(&yolo.onnx).ok_or("failed to get ONNX asset")?; + let session_lock = onnx.session.lock().map_err(|e| e.to_string())?; + let session = session_lock.as_ref().ok_or("failed to get session from ONNX asset")?; + + Ok(yolo_inference( + session, + image, + 0.5, + )) + })(); + + match bounding_boxes { + Ok(bounding_boxes) => { + + commands.spawn(NodeBundle { + style: Style { + display: Display::Grid, + width: Val::Percent(100.0), + height: Val::Percent(100.0), + grid_template_columns: RepeatedGridTrack::flex(1, 1.0), + grid_template_rows: RepeatedGridTrack::flex(1, 1.0), + ..default() + }, + background_color: BackgroundColor(Color::DARK_GRAY), + ..default() + }) + .with_children(|builder| { + builder.spawn(ImageBundle { + style: Style { + ..default() + }, + image: UiImage::new(input.image.clone()), + ..default() + }); + + if let Some(first_box) = bounding_boxes.first() { + let x1 = first_box.x1 / image.width() as f32 * window.width(); + let y1 = first_box.y1 / image.height() as f32 * window.height(); + + let bb_width = (first_box.x2 - first_box.x1) / image.width() as f32 * window.width(); + let bb_height = (first_box.y2 - first_box.y1) / image.height() as f32 * window.height(); + + builder.spawn(NodeBundle { + style: Style { + position_type: PositionType::Absolute, + left: Val::Px(x1), + top: Val::Px(y1), + width: Val::Px(bb_width), + height: Val::Px(bb_height), + border: UiRect::all(Val::Px(2.0)), + ..default() + }, + background_color: BackgroundColor(Color::rgba(1.0, 0.0, 0.0, 0.5)), + ..default() + }); + } + }); + + commands.spawn(Camera2dBundle::default()); + + *complete = true; + } + Err(e) => { + eprintln!("inference failed: {}", e); + } + } +}