This is page 5 of 9. Use http://codebase.md/mehmetoguzderin/shaderc-vkrunner-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .devcontainer
│ ├── devcontainer.json
│ ├── docker-compose.yml
│ └── Dockerfile
├── .gitattributes
├── .github
│ └── workflows
│ └── build-push-image.yml
├── .gitignore
├── .vscode
│ └── mcp.json
├── Cargo.lock
├── Cargo.toml
├── Dockerfile
├── LICENSE
├── README.adoc
├── shaderc-vkrunner-mcp.jpg
├── src
│ └── main.rs
└── vkrunner
├── .editorconfig
├── .gitignore
├── .gitlab-ci.yml
├── build.rs
├── Cargo.toml
├── COPYING
├── examples
│ ├── compute-shader.shader_test
│ ├── cooperative-matrix.shader_test
│ ├── depth-buffer.shader_test
│ ├── desc_set_and_binding.shader_test
│ ├── entrypoint.shader_test
│ ├── float-framebuffer.shader_test
│ ├── frexp.shader_test
│ ├── geometry.shader_test
│ ├── indices.shader_test
│ ├── layouts.shader_test
│ ├── properties.shader_test
│ ├── push-constants.shader_test
│ ├── require-subgroup-size.shader_test
│ ├── row-major.shader_test
│ ├── spirv.shader_test
│ ├── ssbo.shader_test
│ ├── tolerance.shader_test
│ ├── tricolore.shader_test
│ ├── ubo.shader_test
│ ├── vertex-data-piglit.shader_test
│ └── vertex-data.shader_test
├── include
│ ├── vk_video
│ │ ├── vulkan_video_codec_av1std_decode.h
│ │ ├── vulkan_video_codec_av1std_encode.h
│ │ ├── vulkan_video_codec_av1std.h
│ │ ├── vulkan_video_codec_h264std_decode.h
│ │ ├── vulkan_video_codec_h264std_encode.h
│ │ ├── vulkan_video_codec_h264std.h
│ │ ├── vulkan_video_codec_h265std_decode.h
│ │ ├── vulkan_video_codec_h265std_encode.h
│ │ ├── vulkan_video_codec_h265std.h
│ │ └── vulkan_video_codecs_common.h
│ └── vulkan
│ ├── vk_platform.h
│ ├── vulkan_core.h
│ └── vulkan.h
├── precompile-script.py
├── README.md
├── scripts
│ └── update-vulkan.sh
├── src
│ └── main.rs
├── test-build.sh
└── vkrunner
├── allocate_store.rs
├── buffer.rs
├── compiler
│ └── fake_process.rs
├── compiler.rs
├── config.rs
├── context.rs
├── enum_table.rs
├── env_var_test.rs
├── executor.rs
├── fake_vulkan.rs
├── features.rs
├── flush_memory.rs
├── format_table.rs
├── format.rs
├── half_float.rs
├── hex.rs
├── inspect.rs
├── lib.rs
├── logger.rs
├── make-enums.py
├── make-features.py
├── make-formats.py
├── make-pipeline-key-data.py
├── make-vulkan-funcs-data.py
├── parse_num.rs
├── pipeline_key_data.rs
├── pipeline_key.rs
├── pipeline_set.rs
├── requirements.rs
├── result.rs
├── script.rs
├── shader_stage.rs
├── slot.rs
├── small_float.rs
├── source.rs
├── stream.rs
├── temp_file.rs
├── tester.rs
├── tolerance.rs
├── util.rs
├── vbo.rs
├── vk.rs
├── vulkan_funcs_data.rs
├── vulkan_funcs.rs
├── window_format.rs
└── window.rs
```
# Files
--------------------------------------------------------------------------------
/vkrunner/vkrunner/pipeline_set.rs:
--------------------------------------------------------------------------------
```rust
1 | // vkrunner
2 | //
3 | // Copyright (C) 2018 Intel Corporation
4 | // Copypright 2023 Neil Roberts
5 | //
6 | // Permission is hereby granted, free of charge, to any person obtaining a
7 | // copy of this software and associated documentation files (the "Software"),
8 | // to deal in the Software without restriction, including without limitation
9 | // the rights to use, copy, modify, merge, publish, distribute, sublicense,
10 | // and/or sell copies of the Software, and to permit persons to whom the
11 | // Software is furnished to do so, subject to the following conditions:
12 | //
13 | // The above copyright notice and this permission notice (including the next
14 | // paragraph) shall be included in all copies or substantial portions of the
15 | // Software.
16 | //
17 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
20 | // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
22 | // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
23 | // DEALINGS IN THE SOFTWARE.
24 |
25 | use crate::window::Window;
26 | use crate::compiler;
27 | use crate::shader_stage;
28 | use crate::vk;
29 | use crate::script::{Script, Buffer, BufferType, Operation};
30 | use crate::pipeline_key;
31 | use crate::logger::Logger;
32 | use crate::vbo::Vbo;
33 | use std::rc::Rc;
34 | use std::ptr;
35 | use std::mem;
36 | use std::fmt;
37 |
/// The full set of Vulkan objects needed to run a script: the compiled
/// shader modules, a pipeline cache, the pipeline layout, the optional
/// descriptor pool and set layouts, and one pipeline per pipeline key in
/// the script.
///
/// Fields are dropped in declaration order, so the pipelines are
/// destroyed before the layout, descriptor data, cache and shader modules
/// that were used to create them. Keep that order when editing.
#[derive(Debug)]
pub struct PipelineSet {
    // One pipeline per entry of `Script::pipeline_keys()`, in the same
    // order.
    pipelines: PipelineVec,
    layout: PipelineLayout,
    // The descriptor data is only created if there are buffers in the
    // script
    descriptor_data: Option<DescriptorData>,
    // Bitmask with one flag for each stage that has a shader in the script
    stages: vk::VkShaderStageFlagBits,
    // These are never read but they should probably be kept alive
    // until the pipelines are destroyed.
    _pipeline_cache: PipelineCache,
    // Indexed by `shader_stage::Stage as usize`; `None` for stages that
    // have no shader in the script.
    _modules: [Option<ShaderModule>; shader_stage::N_STAGES],
}
51 |
/// Vertex layout used by pipelines whose source is a rectangle.
///
/// `#[repr(C)]` keeps the three `f32` fields in declaration order with no
/// padding. `VertexInputState` uses the size of this struct as the vertex
/// stride and reads it at offset 0 as a `VK_FORMAT_R32G32B32_SFLOAT`
/// attribute, so the layout must not change.
#[repr(C)]
#[derive(Debug)]
pub struct RectangleVertex {
    pub x: f32,
    pub y: f32,
    pub z: f32,
}
59 |
/// An error that can be returned by [PipelineSet::new].
#[derive(Debug)]
pub enum Error {
    /// Compiling one of the shaders in the script failed
    CompileError(compiler::Error),
    /// vkCreatePipelineCache failed
    CreatePipelineCacheFailed,
    /// vkCreateDescriptorPool failed
    CreateDescriptorPoolFailed,
    /// vkCreateDescriptorSetLayout failed
    CreateDescriptorSetLayoutFailed,
    /// vkCreatePipelineLayout failed
    CreatePipelineLayoutFailed,
    /// vkCreateGraphicsPipelines or vkCreateComputePipelines failed
    CreatePipelineFailed,
}
76 |
77 | impl From<compiler::Error> for Error {
78 | fn from(e: compiler::Error) -> Error {
79 | Error::CompileError(e)
80 | }
81 | }
82 |
83 | impl fmt::Display for Error {
84 | fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
85 | match self {
86 | Error::CompileError(e) => e.fmt(f),
87 | Error::CreatePipelineCacheFailed => {
88 | write!(f, "vkCreatePipelineCache failed")
89 | },
90 | Error::CreateDescriptorPoolFailed => {
91 | write!(f, "vkCreateDescriptorPool failed")
92 | },
93 | Error::CreateDescriptorSetLayoutFailed => {
94 | write!(f, "vkCreateDescriptorSetLayout failed")
95 | },
96 | Error::CreatePipelineLayoutFailed => {
97 | write!(f, "vkCreatePipelineLayout failed")
98 | },
99 | Error::CreatePipelineFailed => {
100 | write!(f, "Pipeline creation function failed")
101 | },
102 | }
103 | }
104 | }
105 |
106 | #[derive(Debug)]
107 | struct ShaderModule {
108 | handle: vk::VkShaderModule,
109 | // needed for the destructor
110 | window: Rc<Window>,
111 | }
112 |
113 | impl Drop for ShaderModule {
114 | fn drop(&mut self) {
115 | unsafe {
116 | self.window.device().vkDestroyShaderModule.unwrap()(
117 | self.window.vk_device(),
118 | self.handle,
119 | ptr::null(), // allocator
120 | );
121 | }
122 | }
123 | }
124 |
125 | #[derive(Debug)]
126 | struct PipelineCache {
127 | handle: vk::VkPipelineCache,
128 | // needed for the destructor
129 | window: Rc<Window>,
130 | }
131 |
132 | impl Drop for PipelineCache {
133 | fn drop(&mut self) {
134 | unsafe {
135 | self.window.device().vkDestroyPipelineCache.unwrap()(
136 | self.window.vk_device(),
137 | self.handle,
138 | ptr::null(), // allocator
139 | );
140 | }
141 | }
142 | }
143 |
144 | impl PipelineCache {
145 | fn new(window: Rc<Window>) -> Result<PipelineCache, Error> {
146 | let create_info = vk::VkPipelineCacheCreateInfo {
147 | sType: vk::VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO,
148 | flags: 0,
149 | pNext: ptr::null(),
150 | initialDataSize: 0,
151 | pInitialData: ptr::null(),
152 | };
153 |
154 | let mut handle = vk::null_handle();
155 |
156 | let res = unsafe {
157 | window.device().vkCreatePipelineCache.unwrap()(
158 | window.vk_device(),
159 | ptr::addr_of!(create_info),
160 | ptr::null(), // allocator
161 | ptr::addr_of_mut!(handle),
162 | )
163 | };
164 |
165 | if res == vk::VK_SUCCESS {
166 | Ok(PipelineCache { handle, window })
167 | } else {
168 | Err(Error::CreatePipelineCacheFailed)
169 | }
170 | }
171 | }
172 |
173 | #[derive(Debug)]
174 | struct DescriptorPool {
175 | handle: vk::VkDescriptorPool,
176 | // needed for the destructor
177 | window: Rc<Window>,
178 | }
179 |
180 | impl Drop for DescriptorPool {
181 | fn drop(&mut self) {
182 | unsafe {
183 | self.window.device().vkDestroyDescriptorPool.unwrap()(
184 | self.window.vk_device(),
185 | self.handle,
186 | ptr::null(), // allocator
187 | );
188 | }
189 | }
190 | }
191 |
192 | impl DescriptorPool {
193 | fn new(
194 | window: Rc<Window>,
195 | buffers: &[Buffer],
196 | ) -> Result<DescriptorPool, Error> {
197 | let mut n_ubos = 0;
198 | let mut n_ssbos = 0;
199 |
200 | for buffer in buffers {
201 | match buffer.buffer_type {
202 | BufferType::Ubo => n_ubos += 1,
203 | BufferType::Ssbo => n_ssbos += 1,
204 | }
205 | }
206 |
207 | let mut pool_sizes = Vec::<vk::VkDescriptorPoolSize>::new();
208 |
209 | if n_ubos > 0 {
210 | pool_sizes.push(vk::VkDescriptorPoolSize {
211 | type_: vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
212 | descriptorCount: n_ubos,
213 | });
214 | }
215 | if n_ssbos > 0 {
216 | pool_sizes.push(vk::VkDescriptorPoolSize {
217 | type_: vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
218 | descriptorCount: n_ssbos,
219 | });
220 | }
221 |
222 | // The descriptor pool shouldn’t have been created if there
223 | // were no buffers
224 | assert!(!pool_sizes.is_empty());
225 |
226 | let create_info = vk::VkDescriptorPoolCreateInfo {
227 | sType: vk::VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO,
228 | flags: vk::VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT,
229 | pNext: ptr::null(),
230 | maxSets: n_desc_sets(buffers) as u32,
231 | poolSizeCount: pool_sizes.len() as u32,
232 | pPoolSizes: pool_sizes.as_ptr(),
233 | };
234 |
235 | let mut handle = vk::null_handle();
236 |
237 | let res = unsafe {
238 | window.device().vkCreateDescriptorPool.unwrap()(
239 | window.vk_device(),
240 | ptr::addr_of!(create_info),
241 | ptr::null(), // allocator
242 | ptr::addr_of_mut!(handle),
243 | )
244 | };
245 |
246 | if res == vk::VK_SUCCESS {
247 | Ok(DescriptorPool { handle, window })
248 | } else {
249 | Err(Error::CreateDescriptorPoolFailed)
250 | }
251 | }
252 | }
253 |
254 | #[derive(Debug)]
255 | struct DescriptorSetLayoutVec {
256 | handles: Vec<vk::VkDescriptorSetLayout>,
257 | // needed for the destructor
258 | window: Rc<Window>,
259 | }
260 |
261 | impl Drop for DescriptorSetLayoutVec {
262 | fn drop(&mut self) {
263 | for &handle in self.handles.iter() {
264 | unsafe {
265 | self.window.device().vkDestroyDescriptorSetLayout.unwrap()(
266 | self.window.vk_device(),
267 | handle,
268 | ptr::null(), // allocator
269 | );
270 | }
271 | }
272 | }
273 | }
274 |
275 | impl DescriptorSetLayoutVec {
276 | fn new(window: Rc<Window>) -> DescriptorSetLayoutVec {
277 | DescriptorSetLayoutVec { window, handles: Vec::new() }
278 | }
279 |
280 | fn len(&self) -> usize {
281 | self.handles.len()
282 | }
283 |
284 | fn as_ptr(&self) -> *const vk::VkDescriptorSetLayout {
285 | self.handles.as_ptr()
286 | }
287 |
288 | fn add(
289 | &mut self,
290 | bindings: &[vk::VkDescriptorSetLayoutBinding],
291 | ) -> Result<(), Error> {
292 | let create_info = vk::VkDescriptorSetLayoutCreateInfo {
293 | sType: vk::VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO,
294 | flags: 0,
295 | pNext: ptr::null(),
296 | bindingCount: bindings.len() as u32,
297 | pBindings: bindings.as_ptr(),
298 | };
299 |
300 | let mut handle = vk::null_handle();
301 |
302 | let res = unsafe {
303 | self.window.device().vkCreateDescriptorSetLayout.unwrap()(
304 | self.window.vk_device(),
305 | ptr::addr_of!(create_info),
306 | ptr::null(), // allocator
307 | ptr::addr_of_mut!(handle),
308 | )
309 | };
310 |
311 | if res == vk::VK_SUCCESS {
312 | self.handles.push(handle);
313 | Ok(())
314 | } else {
315 | Err(Error::CreateDescriptorSetLayoutFailed)
316 | }
317 | }
318 | }
319 |
320 | fn create_descriptor_set_layouts(
321 | window: Rc<Window>,
322 | buffers: &[Buffer],
323 | stages: vk::VkShaderStageFlags,
324 | ) -> Result<DescriptorSetLayoutVec, Error> {
325 | let n_desc_sets = n_desc_sets(buffers);
326 | let mut layouts = DescriptorSetLayoutVec::new(window);
327 | let mut bindings = Vec::new();
328 | let mut buffer_num = 0;
329 |
330 | for desc_set in 0..n_desc_sets {
331 | bindings.clear();
332 |
333 | while buffer_num < buffers.len()
334 | && buffers[buffer_num].desc_set as usize == desc_set
335 | {
336 | let buffer = &buffers[buffer_num];
337 |
338 | bindings.push(vk::VkDescriptorSetLayoutBinding {
339 | binding: buffer.binding,
340 | descriptorType: match buffer.buffer_type {
341 | BufferType::Ubo => vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
342 | BufferType::Ssbo => vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
343 | },
344 | descriptorCount: 1,
345 | stageFlags: stages,
346 | pImmutableSamplers: ptr::null(),
347 | });
348 |
349 | buffer_num += 1;
350 | }
351 |
352 | layouts.add(&bindings)?;
353 | }
354 |
355 | assert_eq!(layouts.len(), n_desc_sets);
356 |
357 | Ok(layouts)
358 | }
359 |
360 | fn n_desc_sets(buffers: &[Buffer]) -> usize {
361 | match buffers.last() {
362 | // The number of descriptor sets is the highest used
363 | // descriptor set index + 1. The buffers are in order so the
364 | // highest one should be the last one.
365 | Some(last) => last.desc_set as usize + 1,
366 | None => 0,
367 | }
368 | }
369 |
370 | #[derive(Debug)]
371 | struct DescriptorData {
372 | pool: DescriptorPool,
373 | layouts: DescriptorSetLayoutVec,
374 | }
375 |
376 | impl DescriptorData {
377 | fn new(
378 | window: &Rc<Window>,
379 | buffers: &[Buffer],
380 | stages: vk::VkShaderStageFlagBits,
381 | ) -> Result<DescriptorData, Error> {
382 | let pool = DescriptorPool::new(
383 | Rc::clone(&window),
384 | buffers,
385 | )?;
386 |
387 | let layouts = create_descriptor_set_layouts(
388 | Rc::clone(&window),
389 | buffers,
390 | stages,
391 | )?;
392 |
393 | Ok(DescriptorData { pool, layouts })
394 | }
395 | }
396 |
397 | fn compile_shaders(
398 | logger: &mut Logger,
399 | window: &Rc<Window>,
400 | script: &Script,
401 | show_disassembly: bool,
402 | ) -> Result<[Option<ShaderModule>; shader_stage::N_STAGES], Error> {
403 | let mut modules: [Option<ShaderModule>; shader_stage::N_STAGES] =
404 | Default::default();
405 |
406 | for &stage in shader_stage::ALL_STAGES.iter() {
407 | if script.shaders(stage).is_empty() {
408 | continue;
409 | }
410 |
411 | modules[stage as usize] = Some(ShaderModule {
412 | handle: compiler::build_stage(
413 | logger,
414 | window.context(),
415 | script,
416 | stage,
417 | show_disassembly,
418 | )?,
419 | window: Rc::clone(window),
420 | });
421 | }
422 |
423 | Ok(modules)
424 | }
425 |
426 | fn stage_flags(script: &Script) -> vk::VkShaderStageFlagBits {
427 | // Set a flag for each stage that has a shader in the script
428 | shader_stage::ALL_STAGES
429 | .iter()
430 | .filter_map(|&stage| if script.shaders(stage).is_empty() {
431 | None
432 | } else {
433 | Some(stage.flag())
434 | })
435 | .fold(0, |a, b| a | b)
436 | }
437 |
438 | fn push_constant_size(script: &Script) -> usize {
439 | script.commands()
440 | .iter()
441 | .map(|command| match &command.op {
442 | Operation::SetPushCommand { offset, data } => offset + data.len(),
443 | _ => 0,
444 | })
445 | .max()
446 | .unwrap_or(0)
447 | }
448 |
449 | #[derive(Debug)]
450 | struct PipelineLayout {
451 | handle: vk::VkPipelineLayout,
452 | // needed for the destructor
453 | window: Rc<Window>,
454 | }
455 |
456 | impl Drop for PipelineLayout {
457 | fn drop(&mut self) {
458 | unsafe {
459 | self.window.device().vkDestroyPipelineLayout.unwrap()(
460 | self.window.vk_device(),
461 | self.handle,
462 | ptr::null(), // allocator
463 | );
464 | }
465 | }
466 | }
467 |
468 | impl PipelineLayout {
469 | fn new(
470 | window: Rc<Window>,
471 | script: &Script,
472 | stages: vk::VkShaderStageFlagBits,
473 | descriptor_data: Option<&DescriptorData>
474 | ) -> Result<PipelineLayout, Error> {
475 | let mut handle = vk::null_handle();
476 |
477 | let push_constant_range = vk::VkPushConstantRange {
478 | stageFlags: stages,
479 | offset: 0,
480 | size: push_constant_size(script) as u32,
481 | };
482 |
483 | let mut create_info = vk::VkPipelineLayoutCreateInfo::default();
484 |
485 | create_info.sType = vk::VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
486 |
487 | if push_constant_range.size > 0 {
488 | create_info.pushConstantRangeCount = 1;
489 | create_info.pPushConstantRanges =
490 | ptr::addr_of!(push_constant_range);
491 | }
492 |
493 | if let Some(descriptor_data) = descriptor_data {
494 | create_info.setLayoutCount = descriptor_data.layouts.len() as u32;
495 | create_info.pSetLayouts = descriptor_data.layouts.as_ptr();
496 | }
497 |
498 | let res = unsafe {
499 | window.device().vkCreatePipelineLayout.unwrap()(
500 | window.vk_device(),
501 | ptr::addr_of!(create_info),
502 | ptr::null(), // allocator
503 | ptr::addr_of_mut!(handle),
504 | )
505 | };
506 |
507 | if res == vk::VK_SUCCESS {
508 | Ok(PipelineLayout { handle, window })
509 | } else {
510 | Err(Error::CreatePipelineLayoutFailed)
511 | }
512 | }
513 | }
514 |
/// Vertex input state for one graphics pipeline, together with the arrays
/// that its `create_info` points into.
#[derive(Debug)]
struct VertexInputState {
    // Its binding/attribute pointers refer to the heap buffers of the two
    // vectors below
    create_info: vk::VkPipelineVertexInputStateCreateInfo,
    // These are not read but the create_info struct keeps pointers to
    // them so they need to be kept alive
    _input_bindings: Vec::<vk::VkVertexInputBindingDescription>,
    _attribs: Vec::<vk::VkVertexInputAttributeDescription>,
}
523 |
524 | impl VertexInputState {
525 | fn new(script: &Script, key: &pipeline_key::Key) -> VertexInputState {
526 | let mut input_bindings = Vec::new();
527 | let mut attribs = Vec::new();
528 |
529 | match key.source() {
530 | pipeline_key::Source::Rectangle => {
531 | input_bindings.push(vk::VkVertexInputBindingDescription {
532 | binding: 0,
533 | stride: mem::size_of::<RectangleVertex>() as u32,
534 | inputRate: vk::VK_VERTEX_INPUT_RATE_VERTEX,
535 | });
536 | attribs.push(vk::VkVertexInputAttributeDescription {
537 | location: 0,
538 | binding: 0,
539 | format: vk::VK_FORMAT_R32G32B32_SFLOAT,
540 | offset: 0,
541 | });
542 | },
543 | pipeline_key::Source::VertexData => {
544 | if let Some(vbo) = script.vertex_data() {
545 | VertexInputState::set_up_vertex_data_attribs(
546 | &mut input_bindings,
547 | &mut attribs,
548 | vbo
549 | );
550 | }
551 | }
552 | }
553 |
554 | VertexInputState {
555 | create_info: vk::VkPipelineVertexInputStateCreateInfo {
556 | sType:
557 | vk::VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO,
558 | flags: 0,
559 | pNext: ptr::null(),
560 | vertexBindingDescriptionCount: input_bindings.len() as u32,
561 | pVertexBindingDescriptions: input_bindings.as_ptr(),
562 | vertexAttributeDescriptionCount: attribs.len() as u32,
563 | pVertexAttributeDescriptions: attribs.as_ptr(),
564 | },
565 | _input_bindings: input_bindings,
566 | _attribs: attribs,
567 | }
568 | }
569 |
570 | fn set_up_vertex_data_attribs(
571 | input_bindings: &mut Vec::<vk::VkVertexInputBindingDescription>,
572 | attribs: &mut Vec::<vk::VkVertexInputAttributeDescription>,
573 | vbo: &Vbo,
574 | ) {
575 | input_bindings.push(vk::VkVertexInputBindingDescription {
576 | binding: 0,
577 | stride: vbo.stride() as u32,
578 | inputRate: vk::VK_VERTEX_INPUT_RATE_VERTEX,
579 | });
580 |
581 | for attrib in vbo.attribs().iter() {
582 | attribs.push(vk::VkVertexInputAttributeDescription {
583 | location: attrib.location(),
584 | binding: 0,
585 | format: attrib.format().vk_format,
586 | offset: attrib.offset() as u32,
587 | });
588 | }
589 | }
590 | }
591 |
592 | #[derive(Debug)]
593 | struct PipelineVec {
594 | handles: Vec<vk::VkPipeline>,
595 | // needed for the destructor
596 | window: Rc<Window>,
597 | }
598 |
599 | impl Drop for PipelineVec {
600 | fn drop(&mut self) {
601 | for &handle in self.handles.iter() {
602 | unsafe {
603 | self.window.device().vkDestroyPipeline.unwrap()(
604 | self.window.vk_device(),
605 | handle,
606 | ptr::null(), // allocator
607 | );
608 | }
609 | }
610 | }
611 | }
612 |
impl PipelineVec {
    /// Creates one pipeline for each key in `script.pipeline_keys()`, in
    /// order, and returns them wrapped in a `PipelineVec`.
    ///
    /// The first graphics pipeline is created with
    /// `VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT` when the script has more
    /// than one pipeline key. Every later graphics pipeline is created as a
    /// derivative of it. Compute pipelines use the compute shader module
    /// (or a null handle if there is none) and the script's required
    /// subgroup size, if any.
    ///
    /// If creating any pipeline fails, the error is returned early and the
    /// pipelines created so far are destroyed when `vec` is dropped.
    fn new(
        window: Rc<Window>,
        script: &Script,
        pipeline_cache: vk::VkPipelineCache,
        layout: vk::VkPipelineLayout,
        modules: &[Option<ShaderModule>],
    ) -> Result<PipelineVec, Error> {
        let mut vec = PipelineVec { window, handles: Vec::new() };

        // The first graphics pipeline created, used as the base pipeline
        // for all later graphics pipelines
        let mut first_graphics_pipeline: Option<vk::VkPipeline> = None;

        for key in script.pipeline_keys().iter() {
            let pipeline = match key.pipeline_type() {
                pipeline_key::Type::Graphics => {
                    // Only the first graphics pipeline may act as a parent,
                    // and only when there is more than one key
                    let allow_derivatives =
                        first_graphics_pipeline.is_none()
                        && script.pipeline_keys().len() > 1;

                    let pipeline = PipelineVec::create_graphics_pipeline(
                        &vec.window,
                        script,
                        key,
                        pipeline_cache,
                        layout,
                        modules,
                        allow_derivatives,
                        first_graphics_pipeline,
                    )?;

                    // Remember it only if this is the first one
                    first_graphics_pipeline.get_or_insert(pipeline);

                    pipeline
                },
                pipeline_key::Type::Compute => {
                    PipelineVec::create_compute_pipeline(
                        &vec.window,
                        key,
                        pipeline_cache,
                        layout,
                        modules[shader_stage::Stage::Compute as usize]
                            .as_ref()
                            .map(|m| m.handle)
                            .unwrap_or(vk::null_handle()),
                        script.requirements().required_subgroup_size,
                    )?
                },
            };

            vec.handles.push(pipeline);
        }

        Ok(vec)
    }

    /// Returns the key's entrypoint name for `stage` with a trailing NUL
    /// byte appended, so its pointer can be passed to Vulkan as `pName`.
    /// NOTE(review): assumes the name itself contains no interior NUL.
    fn null_terminated_entrypoint(
        key: &pipeline_key::Key,
        stage: shader_stage::Stage,
    ) -> String {
        let mut entrypoint = key.entrypoint(stage).to_string();
        entrypoint.push('\0');
        entrypoint
    }

    /// Builds a `VkPipelineShaderStageCreateInfo` for every non-compute
    /// stage that has a module.
    ///
    /// `entrypoints` must be indexed by `stage as usize`. The returned
    /// structs point into its strings, so it must outlive them.
    fn create_stages(
        modules: &[Option<ShaderModule>],
        entrypoints: &[String],
    ) -> Vec<vk::VkPipelineShaderStageCreateInfo> {
        let mut stages = Vec::new();

        for &stage in shader_stage::ALL_STAGES.iter() {
            // Compute shaders are handled by create_compute_pipeline
            if stage == shader_stage::Stage::Compute {
                continue;
            }

            let module = match &modules[stage as usize] {
                Some(module) => module.handle,
                None => continue,
            };

            stages.push(vk::VkPipelineShaderStageCreateInfo {
                sType: vk::VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
                flags: 0,
                pNext: ptr::null(),
                stage: stage.flag(),
                module,
                pName: entrypoints[stage as usize].as_ptr().cast(),
                pSpecializationInfo: ptr::null(),
            });
        }

        stages
    }

    /// Creates a single graphics pipeline for `key`.
    ///
    /// The create info is seeded from the bytes returned by
    /// `key.to_create_info()`. It is then patched with a viewport and
    /// scissor covering the whole window, single-sample multisample state,
    /// the shader stages, the vertex input state, the layout, the window's
    /// first render pass, and the derivative flags and base pipeline.
    ///
    /// Several fields of `create_info` hold pointers to locals of this
    /// function (`create_info_buf`, `entrypoints`, `stages`,
    /// `viewport_state`, `multisample_state`, `vertex_input_state`). All of
    /// them must stay alive until `vkCreateGraphicsPipelines` returns.
    fn create_graphics_pipeline(
        window: &Window,
        script: &Script,
        key: &pipeline_key::Key,
        pipeline_cache: vk::VkPipelineCache,
        layout: vk::VkPipelineLayout,
        modules: &[Option<ShaderModule>],
        allow_derivatives: bool,
        parent_pipeline: Option<vk::VkPipeline>,
    ) -> Result<vk::VkPipeline, Error> {
        // NOTE(review): presumably this buffer begins with a
        // VkGraphicsPipelineCreateInfo whose own pointer fields refer to
        // other state stored in the same buffer, which is why it is kept
        // alive until the end of the function. Confirm in pipeline_key.
        let create_info_buf = key.to_create_info();

        let mut create_info = vk::VkGraphicsPipelineCreateInfo::default();

        // Copy the key's base create info out of the raw byte buffer
        unsafe {
            ptr::copy_nonoverlapping(
                create_info_buf.as_ptr().cast(),
                ptr::addr_of_mut!(create_info),
                1, // count
            );
        }

        // One NUL-terminated name per stage, indexed by stage. The stage
        // create infos below point into these strings.
        let entrypoints = shader_stage::ALL_STAGES
            .iter()
            .map(|&s| PipelineVec::null_terminated_entrypoint(key, s))
            .collect::<Vec<String>>();
        let stages = PipelineVec::create_stages(modules, &entrypoints);

        let window_format = window.format();

        // Static viewport and scissor covering the whole window
        let viewport = vk::VkViewport {
            x: 0.0,
            y: 0.0,
            width: window_format.width as f32,
            height: window_format.height as f32,
            minDepth: 0.0,
            maxDepth: 1.0,
        };

        let scissor = vk::VkRect2D {
            offset: vk::VkOffset2D {
                x: 0,
                y: 0,
            },
            extent: vk::VkExtent2D {
                width: window_format.width as u32,
                height: window_format.height as u32,
            },
        };

        let viewport_state = vk::VkPipelineViewportStateCreateInfo {
            sType: vk::VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO,
            flags: 0,
            pNext: ptr::null(),
            viewportCount: 1,
            pViewports: ptr::addr_of!(viewport),
            scissorCount: 1,
            pScissors: ptr::addr_of!(scissor),
        };

        // Single-sample rendering with no sample shading or alpha tricks
        let multisample_state = vk::VkPipelineMultisampleStateCreateInfo {
            sType: vk::VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO,
            pNext: ptr::null(),
            flags: 0,
            rasterizationSamples: vk::VK_SAMPLE_COUNT_1_BIT,
            sampleShadingEnable: vk::VK_FALSE,
            minSampleShading: 0.0,
            pSampleMask: ptr::null(),
            alphaToCoverageEnable: vk::VK_FALSE,
            alphaToOneEnable: vk::VK_FALSE,
        };

        create_info.pViewportState = ptr::addr_of!(viewport_state);
        create_info.pMultisampleState = ptr::addr_of!(multisample_state);
        create_info.subpass = 0;
        // Derive from the first graphics pipeline by handle, not by index
        create_info.basePipelineHandle =
            parent_pipeline.unwrap_or(vk::null_handle());
        create_info.basePipelineIndex = -1;

        create_info.stageCount = stages.len() as u32;
        create_info.pStages = stages.as_ptr();
        create_info.layout = layout;
        create_info.renderPass = window.render_passes()[0];

        if allow_derivatives {
            create_info.flags |= vk::VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT;
        }
        if parent_pipeline.is_some() {
            create_info.flags |= vk::VK_PIPELINE_CREATE_DERIVATIVE_BIT;
        }

        // Without tessellation shaders, don't pass any tessellation state
        if modules[shader_stage::Stage::TessCtrl as usize].is_none()
            && modules[shader_stage::Stage::TessEval as usize].is_none()
        {
            create_info.pTessellationState = ptr::null();
        }

        let vertex_input_state = VertexInputState::new(script, key);
        create_info.pVertexInputState =
            ptr::addr_of!(vertex_input_state.create_info);

        let mut handle = vk::null_handle();

        let res = unsafe {
            window.device().vkCreateGraphicsPipelines.unwrap()(
                window.vk_device(),
                pipeline_cache,
                1, // nCreateInfos
                ptr::addr_of!(create_info),
                ptr::null(), // allocator
                ptr::addr_of_mut!(handle),
            )
        };

        if res == vk::VK_SUCCESS {
            Ok(handle)
        } else {
            Err(Error::CreatePipelineFailed)
        }
    }

    /// Creates a compute pipeline for `key` using `module`.
    ///
    /// When `required_subgroup_size` is given, a
    /// `VkPipelineShaderStageRequiredSubgroupSizeCreateInfo` is chained
    /// onto the shader stage's `pNext`.
    fn create_compute_pipeline(
        window: &Window,
        key: &pipeline_key::Key,
        pipeline_cache: vk::VkPipelineCache,
        layout: vk::VkPipelineLayout,
        module: vk::VkShaderModule,
        required_subgroup_size: Option<u32>,
    ) -> Result<vk::VkPipeline, Error> {
        // The stage's pName points into this string
        let entrypoint = PipelineVec::null_terminated_entrypoint(
            key,
            shader_stage::Stage::Compute,
        );

        // Kept in a local so the pNext pointer below stays valid until the
        // create call returns
        let rss_info = required_subgroup_size.map(|size|
            vk::VkPipelineShaderStageRequiredSubgroupSizeCreateInfo {
                sType: vk::VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO,
                pNext: ptr::null_mut(),
                requiredSubgroupSize: size,
            });

        let create_info = vk::VkComputePipelineCreateInfo {
            sType: vk::VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
            pNext: ptr::null(),
            flags: 0,

            stage: vk::VkPipelineShaderStageCreateInfo {
                sType: vk::VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
                pNext: if let Some(ref rss_info) = rss_info {
                    ptr::addr_of!(*rss_info).cast()
                } else {
                    ptr::null_mut()
                },
                flags: 0,
                stage: vk::VK_SHADER_STAGE_COMPUTE_BIT,
                module,
                pName: entrypoint.as_ptr().cast(),
                pSpecializationInfo: ptr::null(),
            },

            layout,
            basePipelineHandle: vk::null_handle(),
            basePipelineIndex: -1,
        };

        let mut handle = vk::null_handle();

        let res = unsafe {
            window.device().vkCreateComputePipelines.unwrap()(
                window.vk_device(),
                pipeline_cache,
                1, // nCreateInfos
                ptr::addr_of!(create_info),
                ptr::null(), // allocator
                ptr::addr_of_mut!(handle),
            )
        };

        if res == vk::VK_SUCCESS {
            Ok(handle)
        } else {
            Err(Error::CreatePipelineFailed)
        }
    }
}
892 |
893 | impl PipelineSet {
894 | pub fn new(
895 | logger: &mut Logger,
896 | window: Rc<Window>,
897 | script: &Script,
898 | show_disassembly: bool,
899 | ) -> Result<PipelineSet, Error> {
900 | let modules = compile_shaders(
901 | logger,
902 | &window,
903 | script,
904 | show_disassembly
905 | )?;
906 |
907 | let pipeline_cache = PipelineCache::new(Rc::clone(&window))?;
908 |
909 | let stages = stage_flags(script);
910 |
911 | let descriptor_data = if script.buffers().is_empty() {
912 | None
913 | } else {
914 | Some(DescriptorData::new(&window, script.buffers(), stages)?)
915 | };
916 |
917 | let layout = PipelineLayout::new(
918 | Rc::clone(&window),
919 | script,
920 | stages,
921 | descriptor_data.as_ref(),
922 | )?;
923 |
924 | let pipelines = PipelineVec::new(
925 | Rc::clone(&window),
926 | script,
927 | pipeline_cache.handle,
928 | layout.handle,
929 | &modules,
930 | )?;
931 |
932 | Ok(PipelineSet {
933 | _modules: modules,
934 | _pipeline_cache: pipeline_cache,
935 | stages,
936 | descriptor_data,
937 | layout,
938 | pipelines,
939 | })
940 | }
941 |
942 | pub fn descriptor_set_layouts(&self) -> &[vk::VkDescriptorSetLayout] {
943 | match self.descriptor_data.as_ref() {
944 | Some(data) => &data.layouts.handles,
945 | None => &[],
946 | }
947 | }
948 |
949 | pub fn stages(&self) -> vk::VkShaderStageFlagBits {
950 | self.stages
951 | }
952 |
953 | pub fn layout(&self) -> vk::VkPipelineLayout {
954 | self.layout.handle
955 | }
956 |
957 | pub fn pipelines(&self) -> &[vk::VkPipeline] {
958 | &self.pipelines.handles
959 | }
960 |
961 | pub fn descriptor_pool(&self) -> Option<vk::VkDescriptorPool> {
962 | self.descriptor_data.as_ref().map(|data| data.pool.handle)
963 | }
964 | }
965 |
966 | #[cfg(test)]
967 | mod test {
968 | use super::*;
969 | use crate::fake_vulkan::{FakeVulkan, HandleType, PipelineCreateInfo};
970 | use crate::fake_vulkan::GraphicsPipelineCreateInfo;
971 | use crate::fake_vulkan::PipelineLayoutCreateInfo;
972 | use crate::context::Context;
973 | use crate::requirements::Requirements;
974 | use crate::source::Source;
975 | use crate::config::Config;
976 |
977 | #[derive(Debug)]
978 | struct TestData {
979 | pipeline_set: PipelineSet,
980 | window: Rc<Window>,
981 | fake_vulkan: Box<FakeVulkan>,
982 | }
983 |
984 | impl TestData {
985 | fn new_with_errors<F: FnOnce(&mut FakeVulkan)>(
986 | source: &str,
987 | queue_errors: F
988 | ) -> Result<TestData, Error> {
989 | let mut fake_vulkan = FakeVulkan::new();
990 |
991 | fake_vulkan.physical_devices.push(Default::default());
992 | fake_vulkan.physical_devices[0].format_properties.insert(
993 | vk::VK_FORMAT_B8G8R8A8_UNORM,
994 | vk::VkFormatProperties {
995 | linearTilingFeatures: 0,
996 | optimalTilingFeatures:
997 | vk::VK_FORMAT_FEATURE_COLOR_ATTACHMENT_BIT
998 | | vk::VK_FORMAT_FEATURE_BLIT_SRC_BIT,
999 | bufferFeatures: 0,
1000 | },
1001 | );
1002 |
1003 | let memory_properties =
1004 | &mut fake_vulkan.physical_devices[0].memory_properties;
1005 | memory_properties.memoryTypes[0].propertyFlags =
1006 | vk::VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT;
1007 | memory_properties.memoryTypeCount = 1;
1008 | fake_vulkan.memory_requirements.memoryTypeBits = 1;
1009 |
1010 | fake_vulkan.set_override();
1011 | let context = Rc::new(Context::new(
1012 | &Requirements::new(),
1013 | None, // device_id
1014 | ).unwrap());
1015 |
1016 | let window = Rc::new(Window::new(
1017 | Rc::clone(&context),
1018 | &Default::default(), // format
1019 | ).unwrap());
1020 |
1021 | queue_errors(&mut fake_vulkan);
1022 |
1023 | let mut logger = Logger::new(None, ptr::null_mut());
1024 |
1025 | let source = Source::from_string(source.to_string());
1026 | let script = Script::load(&Config::new(), &source).unwrap();
1027 |
1028 | let pipeline_set = PipelineSet::new(
1029 | &mut logger,
1030 | Rc::clone(&window),
1031 | &script,
1032 | false, // show_disassembly
1033 | )?;
1034 |
1035 | Ok(TestData { fake_vulkan, window, pipeline_set })
1036 | }
1037 |
1038 | fn new(source: &str) -> Result<TestData, Error> {
1039 | TestData::new_with_errors(source, |_| ())
1040 | }
1041 |
1042 | fn graphics_create_info(
1043 | &mut self,
1044 | pipeline_num: usize,
1045 | ) -> GraphicsPipelineCreateInfo {
1046 | let pipeline = self.pipeline_set.pipelines.handles[pipeline_num];
1047 |
1048 | match &self.fake_vulkan.get_handle(pipeline).data {
1049 | HandleType::Pipeline(PipelineCreateInfo::Graphics(info)) => {
1050 | info.clone()
1051 | },
1052 | handle @ _ => {
1053 | unreachable!("unexpected handle type: {:?}", handle)
1054 | },
1055 | }
1056 | }
1057 |
1058 | fn compute_create_info(
1059 | &mut self,
1060 | pipeline_num: usize,
1061 | ) -> vk::VkComputePipelineCreateInfo {
1062 | let pipeline = self.pipeline_set.pipelines.handles[pipeline_num];
1063 |
1064 | match &self.fake_vulkan.get_handle(pipeline).data {
1065 | HandleType::Pipeline(PipelineCreateInfo::Compute(info)) => {
1066 | info.clone()
1067 | },
1068 | handle @ _ => {
1069 | unreachable!("unexpected handle type: {:?}", handle)
1070 | },
1071 | }
1072 | }
1073 |
1074 | fn shader_module_code(
1075 | &mut self,
1076 | stage: shader_stage::Stage,
1077 | ) -> Vec<u32> {
1078 | let module = self.pipeline_set._modules[stage as usize]
1079 | .as_ref()
1080 | .unwrap()
1081 | .handle;
1082 |
1083 | match &self.fake_vulkan.get_handle(module).data {
1084 | HandleType::ShaderModule { code } => code.clone(),
1085 | _ => unreachable!("Unexpected Vulkan handle type"),
1086 | }
1087 | }
1088 |
1089 | fn descriptor_set_layout_bindings(
1090 | &mut self,
1091 | desc_set: usize,
1092 | ) -> Vec<vk::VkDescriptorSetLayoutBinding> {
1093 | let desc_set_layout = self.pipeline_set
1094 | .descriptor_set_layouts()
1095 | [desc_set];
1096 |
1097 | match &self.fake_vulkan.get_handle(desc_set_layout).data {
1098 | HandleType::DescriptorSetLayout { bindings } => {
1099 | bindings.clone()
1100 | },
1101 | _ => unreachable!("Unexpected Vulkan handle type"),
1102 | }
1103 | }
1104 |
1105 | fn pipeline_layout_create_info(&mut self) -> PipelineLayoutCreateInfo {
1106 | let layout = self.pipeline_set.layout();
1107 |
1108 | match &self.fake_vulkan.get_handle(layout).data {
1109 | HandleType::PipelineLayout(create_info) => create_info.clone(),
1110 | _ => unreachable!("Unexpected Vulkan handle type"),
1111 | }
1112 | }
1113 | }
1114 |
/// Builds a pipeline set from a script that uses the vertex, fragment and
/// compute stages, a UBO, an SSBO, a push constant, a compute dispatch and
/// two draw commands, then checks what the fake driver recorded for it.
#[test]
fn base() {
    // The fragment and compute sections are raw hex words, so the exact
    // code stored in the fake shader modules can be checked further down.
    let mut test_data = TestData::new(
        "[vertex shader passthrough]\n\
         [fragment shader]\n\
         03 02 23 07\n\
         fe ca fe ca\n\
         [vertex data]\n\
         0/R32G32_SFLOAT 1/R32_SFLOAT\n\
         -0.5 -0.5 1.52\n\
         0.5 -0.5 1.55\n\
         [compute shader]\n\
         03 02 23 07\n\
         ca fe ca fe\n\
         [test]\n\
         ubo 0 1024\n\
         ssbo 1 1024\n\
         compute 1 1 1\n\
         push float 6 42.0\n\
         draw rect 0 0 1 1\n\
         draw arrays TRIANGLE_LIST 0 2\n"
    ).unwrap();

    // Every stage that the script provides a shader for is reported.
    assert_eq!(
        test_data.pipeline_set.stages(),
        vk::VK_SHADER_STAGE_VERTEX_BIT
            | vk::VK_SHADER_STAGE_FRAGMENT_BIT
            | vk::VK_SHADER_STAGE_COMPUTE_BIT,
    );

    // The script uses buffers, so a descriptor pool must exist.
    let descriptor_pool = test_data.pipeline_set.descriptor_pool();
    assert!(matches!(
        test_data.fake_vulkan.get_handle(descriptor_pool.unwrap()).data,
        HandleType::DescriptorPool,
    ));

    // A compute pipeline plus graphics pipelines for `draw rect` and
    // `draw arrays`.
    assert_eq!(test_data.pipeline_set.pipelines().len(), 3);

    // Pipeline 0 is the compute pipeline.
    let compute_create_info = test_data.compute_create_info(0);
    assert_eq!(compute_create_info.sType, vk::VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO);

    // Pipeline 1 is the `draw rect` pipeline: it allows derivatives, has
    // no base pipeline and reads RectangleVertex-sized vertices with a
    // single three-float attribute at location 0.
    let create_data = test_data.graphics_create_info(1);
    let create_info = &create_data.create_info;

    assert_eq!(
        create_info.flags,
        vk::VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT,
    );
    assert_eq!(create_info.stageCount, 2);
    assert_eq!(create_info.layout, test_data.pipeline_set.layout());
    assert_eq!(
        create_info.renderPass,
        test_data.window.render_passes()[0]
    );
    assert_eq!(create_info.basePipelineHandle, vk::null_handle());

    assert_eq!(create_data.bindings.len(), 1);
    assert_eq!(create_data.bindings[0].binding, 0);
    assert_eq!(
        create_data.bindings[0].stride as usize,
        mem::size_of::<RectangleVertex>(),
    );
    assert_eq!(
        create_data.bindings[0].inputRate,
        vk::VK_VERTEX_INPUT_RATE_VERTEX,
    );

    assert_eq!(create_data.attribs.len(), 1);
    assert_eq!(create_data.attribs[0].location, 0);
    assert_eq!(create_data.attribs[0].binding, 0);
    assert_eq!(
        create_data.attribs[0].format,
        vk::VK_FORMAT_R32G32B32_SFLOAT
    );
    assert_eq!(create_data.attribs[0].offset, 0);

    // Pipeline 2 is the `draw arrays` pipeline. It derives from pipeline
    // 1, and its vertex input follows the `[vertex data]` header: a vec2
    // at location 0 followed by a float at location 1 (offset 8), with a
    // stride of three floats.
    let create_data = test_data.graphics_create_info(2);
    let create_info = &create_data.create_info;

    assert_eq!(
        create_info.flags,
        vk::VK_PIPELINE_CREATE_DERIVATIVE_BIT,
    );
    assert_eq!(create_info.stageCount, 2);
    assert_eq!(create_info.layout, test_data.pipeline_set.layout());
    assert_eq!(
        create_info.renderPass,
        test_data.window.render_passes()[0]
    );
    assert_eq!(
        create_info.basePipelineHandle,
        test_data.pipeline_set.pipelines()[1]
    );

    assert_eq!(create_data.bindings.len(), 1);
    assert_eq!(create_data.bindings[0].binding, 0);
    assert_eq!(
        create_data.bindings[0].stride as usize,
        mem::size_of::<f32>() * 3,
    );
    assert_eq!(
        create_data.bindings[0].inputRate,
        vk::VK_VERTEX_INPUT_RATE_VERTEX,
    );

    assert_eq!(create_data.attribs.len(), 2);
    assert_eq!(create_data.attribs[0].location, 0);
    assert_eq!(create_data.attribs[0].binding, 0);
    assert_eq!(
        create_data.attribs[0].format,
        vk::VK_FORMAT_R32G32_SFLOAT
    );
    assert_eq!(create_data.attribs[0].offset, 0);
    assert_eq!(create_data.attribs[1].location, 1);
    assert_eq!(create_data.attribs[1].binding, 0);
    assert_eq!(
        create_data.attribs[1].format,
        vk::VK_FORMAT_R32_SFLOAT
    );
    assert_eq!(create_data.attribs[1].offset, 8);

    // The passthrough vertex shader is generated, so only check that it
    // contains something beyond a header word or two. The hex sections
    // end up as little-endian 32-bit words.
    let code = test_data.shader_module_code(shader_stage::Stage::Vertex);
    assert!(code.len() > 2);

    let code = test_data.shader_module_code(shader_stage::Stage::Fragment);
    assert_eq!(&code, &[0x07230203, 0xcafecafe]);

    let code = test_data.shader_module_code(shader_stage::Stage::Compute);
    assert_eq!(&code, &[0x07230203, 0xfecafeca]);

    // `ubo 0` and `ssbo 1` both live in descriptor set 0 and are visible
    // to all three stages.
    assert_eq!(test_data.pipeline_set.descriptor_set_layouts().len(), 1);
    let bindings = test_data.descriptor_set_layout_bindings(0);
    assert_eq!(bindings.len(), 2);

    assert_eq!(bindings[0].binding, 0);
    assert_eq!(
        bindings[0].descriptorType,
        vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER
    );
    assert_eq!(bindings[0].descriptorCount, 1);
    assert_eq!(
        bindings[0].stageFlags,
        vk::VK_SHADER_STAGE_VERTEX_BIT
            | vk::VK_SHADER_STAGE_FRAGMENT_BIT
            | vk::VK_SHADER_STAGE_COMPUTE_BIT,
    );

    assert_eq!(bindings[1].binding, 1);
    assert_eq!(
        bindings[1].descriptorType,
        vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
    );
    assert_eq!(bindings[1].descriptorCount, 1);
    assert_eq!(
        bindings[1].stageFlags,
        vk::VK_SHADER_STAGE_VERTEX_BIT
            | vk::VK_SHADER_STAGE_FRAGMENT_BIT
            | vk::VK_SHADER_STAGE_COMPUTE_BIT,
    );

    // `push float 6 42.0` writes a float at offset 6, so the single push
    // constant range must cover 6 + 4 bytes from offset 0.
    let layout = test_data.pipeline_layout_create_info();
    assert_eq!(layout.create_info.sType, vk::VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO);
    assert_eq!(layout.push_constant_ranges.len(), 1);
    assert_eq!(
        layout.push_constant_ranges[0].stageFlags,
        vk::VK_SHADER_STAGE_VERTEX_BIT
            | vk::VK_SHADER_STAGE_FRAGMENT_BIT
            | vk::VK_SHADER_STAGE_COMPUTE_BIT,
    );
    assert_eq!(layout.push_constant_ranges[0].offset, 0);
    assert_eq!(
        layout.push_constant_ranges[0].size as usize,
        6 + mem::size_of::<f32>()
    );
    assert_eq!(
        &layout.layouts,
        test_data.pipeline_set.descriptor_set_layouts()
    );
}
1294 |
/// Checks that descriptor sets the script never mentions still get empty
/// layouts, so the layout list has one entry per set number up to the
/// highest one used.
#[test]
fn descriptor_set_gaps() {
    let mut test_data = TestData::new(
        "[vertex shader passthrough]\n\
         [fragment shader]\n\
         03 02 23 07\n\
         fe ca fe ca\n\
         [test]\n\
         ubo 0 1024\n\
         ssbo 2:1 1024\n\
         ssbo 3:1 1024\n\
         ubo 5:5 1024\n\
         draw rect 0 0 1 1\n"
    ).unwrap();

    // Indexed by descriptor set number: the single binding expected in
    // that set as (binding number, descriptor type), or `None` for a gap
    // that must have an empty layout.
    let expected_sets = [
        Some((0, vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER)),
        None,
        Some((1, vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER)),
        Some((1, vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER)),
        None,
        Some((5, vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER)),
    ];

    assert_eq!(
        test_data.pipeline_set.descriptor_set_layouts().len(),
        expected_sets.len(),
    );

    // Only the vertex and fragment stages are used by this script.
    let expected_stages =
        vk::VK_SHADER_STAGE_VERTEX_BIT | vk::VK_SHADER_STAGE_FRAGMENT_BIT;

    for (desc_set, expected) in expected_sets.iter().enumerate() {
        let bindings = test_data.descriptor_set_layout_bindings(desc_set);

        match expected {
            None => assert_eq!(bindings.len(), 0),
            Some((binding, descriptor_type)) => {
                assert_eq!(bindings.len(), 1);

                let layout_binding = &bindings[0];
                assert_eq!(layout_binding.binding, *binding);
                assert_eq!(layout_binding.descriptorType, *descriptor_type);
                assert_eq!(layout_binding.descriptorCount, 1);
                assert_eq!(layout_binding.stageFlags, expected_stages);
            },
        }
    }
}
1376 |
/// A shader section whose words don't form a valid SPIR-V binary must
/// make pipeline set creation fail with a descriptive error.
#[test]
fn compile_error() {
    let result = TestData::new(
        "[fragment shader]\n\
         12 34 56 78\n"
    );

    let message = result.unwrap_err().to_string();

    assert_eq!(
        message,
        "The compiler or assembler generated an invalid SPIR-V binary",
    );
}
1389 |
/// Each Vulkan object creation call made while building a pipeline set
/// with a UBO is forced to fail once. The resulting error must name the
/// failing call.
#[test]
fn create_errors() {
    const FAILING_COMMANDS: [&str; 4] = [
        "vkCreatePipelineCache",
        "vkCreateDescriptorPool",
        "vkCreateDescriptorSetLayout",
        "vkCreatePipelineLayout",
    ];

    for &command in FAILING_COMMANDS.iter() {
        let result = TestData::new_with_errors(
            "[test]\n\
             ubo 0 10\n\
             draw rect 0 0 1 1\n",
            |fake_vulkan| {
                fake_vulkan.queue_result(
                    command.to_string(),
                    vk::VK_ERROR_UNKNOWN,
                )
            },
        );

        let expected_message = format!("{} failed", command);
        assert_eq!(result.unwrap_err().to_string(), expected_message);
    }
}
1415 |
/// Forcing either pipeline creation entry point to fail must produce the
/// generic pipeline creation error.
#[test]
fn create_pipeline_errors() {
    const FAILING_COMMANDS: [&str; 2] = [
        "vkCreateGraphicsPipelines",
        "vkCreateComputePipelines",
    ];

    for &command in FAILING_COMMANDS.iter() {
        let result = TestData::new_with_errors(
            "[compute shader]\n\
             03 02 23 07\n\
             ca fe ca fe\n\
             [test]\n\
             compute 1 1 1\n\
             draw rect 0 0 1 1\n",
            |fake_vulkan| {
                fake_vulkan.queue_result(
                    command.to_string(),
                    vk::VK_ERROR_UNKNOWN,
                )
            },
        );

        let message = result.unwrap_err().to_string();
        assert_eq!(message, "Pipeline creation function failed");
    }
}
1442 |
/// A script without any buffers needs neither a descriptor pool nor any
/// descriptor set layouts.
#[test]
fn no_buffers() {
    let test_data = TestData::new("").unwrap();
    let pipeline_set = &test_data.pipeline_set;

    assert!(pipeline_set.descriptor_pool().is_none());
    assert!(pipeline_set.descriptor_set_layouts().is_empty());
}
1450 | }
1451 |
```
--------------------------------------------------------------------------------
/vkrunner/vkrunner/slot.rs:
--------------------------------------------------------------------------------
```rust
1 | // vkrunner
2 | //
3 | // Copyright (C) 2018 Intel Corporation
4 | // Copyright 2023 Neil Roberts
5 | //
6 | // Permission is hereby granted, free of charge, to any person obtaining a
7 | // copy of this software and associated documentation files (the "Software"),
8 | // to deal in the Software without restriction, including without limitation
9 | // the rights to use, copy, modify, merge, publish, distribute, sublicense,
10 | // and/or sell copies of the Software, and to permit persons to whom the
11 | // Software is furnished to do so, subject to the following conditions:
12 | //
13 | // The above copyright notice and this permission notice (including the next
14 | // paragraph) shall be included in all copies or substantial portions of the
15 | // Software.
16 | //
17 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
20 | // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
22 | // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
23 | // DEALINGS IN THE SOFTWARE.
24 |
25 | //! Contains utilities for accessing values stored in a buffer
26 | //! according to the GLSL layout rules. These are described in detail
27 | //! in the OpenGL specs here:
28 | //! <https://registry.khronos.org/OpenGL/specs/gl/glspec45.core.pdf#page=159>
29 |
30 | use crate::util;
31 | use crate::tolerance::Tolerance;
32 | use crate::half_float;
33 | use std::mem;
34 | use std::convert::TryInto;
35 | use std::fmt;
36 |
/// The GLSL layout standard used to lay out a buffer. The two standards
/// differ in only one way: under `Std140`, the offset between array
/// members and between elements of a matrix's minor axis is rounded up to
/// a multiple of 16 bytes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LayoutStd {
    /// The `std140` layout rules.
    Std140,
    /// The `std430` layout rules.
    Std430,
}
46 |
/// The major axis of a matrix type, i.e. the axis whose elements are
/// stored next to each other in memory. If the numbers below give the
/// position in memory of each component of a 3×3 matrix, the two
/// orderings look like this:
///
/// ```text
/// Column major    Row major
/// [0 3 6]         [0 1 2]
/// [1 4 7]         [3 4 5]
/// [2 5 8]         [6 7 8]
/// ```
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MajorAxis {
    /// Each column is stored contiguously.
    Column,
    /// Each row is stored contiguously.
    Row,
}
63 |
64 | /// Combines the [MajorAxis] and the [LayoutStd] to provide a complete
65 | /// description of the layout options used for accessing the
66 | /// components of a type.
67 | #[derive(Debug, Clone, Copy, PartialEq, Eq)]
68 | pub struct Layout {
69 | pub std: LayoutStd,
70 | pub major: MajorAxis,
71 | }
72 |
/// An enum representing all of the types that can be stored in a
/// slot. These correspond to the types in GLSL.
///
/// NOTE: the variant order is significant. The accessor methods look up
/// `TYPE_INFOS[self as usize]`, so each variant's discriminant (starting
/// from `Int = 0`) must match the position of its entry in
/// `TYPE_INFOS`. Add or reorder variants only together with that table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Type {
    // Scalars
    Int = 0,
    UInt,
    Int8,
    UInt8,
    Int16,
    UInt16,
    Int64,
    UInt64,
    Float16,
    Float,
    Double,
    // Vectors
    F16Vec2,
    F16Vec3,
    F16Vec4,
    Vec2,
    Vec3,
    Vec4,
    DVec2,
    DVec3,
    DVec4,
    IVec2,
    IVec3,
    IVec4,
    UVec2,
    UVec3,
    UVec4,
    I8Vec2,
    I8Vec3,
    I8Vec4,
    U8Vec2,
    U8Vec3,
    U8Vec4,
    I16Vec2,
    I16Vec3,
    I16Vec4,
    U16Vec2,
    U16Vec3,
    U16Vec4,
    I64Vec2,
    I64Vec3,
    I64Vec4,
    U64Vec2,
    U64Vec3,
    U64Vec4,
    // Matrices, named like GLSL `matCxR` (C columns, R rows)
    Mat2,
    Mat2x3,
    Mat2x4,
    Mat3x2,
    Mat3,
    Mat3x4,
    Mat4x2,
    Mat4x3,
    Mat4,
    DMat2,
    DMat2x3,
    DMat2x4,
    DMat3x2,
    DMat3,
    DMat3x4,
    DMat4x2,
    DMat4x3,
    DMat4,
}
140 |
/// The scalar type of a single component of a [Type]. For example, the
/// components of a `Vec3` and of a `Mat4` are both `Float`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BaseType {
    /// 32-bit signed integer.
    Int,
    /// 32-bit unsigned integer.
    UInt,
    /// 8-bit signed integer.
    Int8,
    /// 8-bit unsigned integer.
    UInt8,
    /// 16-bit signed integer.
    Int16,
    /// 16-bit unsigned integer.
    UInt16,
    /// 64-bit signed integer.
    Int64,
    /// 64-bit unsigned integer.
    UInt64,
    /// 16-bit (half precision) float.
    Float16,
    /// 32-bit float.
    Float,
    /// 64-bit float.
    Double,
}
156 |
/// A base type combined with a slice of bytes. This implements
/// [Display](std::fmt::Display) so it can be used to extract a base
/// type from a byte array and present it as a human-readable string.
/// Create it with [BaseTypeInSlice::new], which checks the slice length.
pub struct BaseTypeInSlice<'a> {
    // How the bytes in `slice` are interpreted.
    base_type: BaseType,
    // Native-endian bytes of one value; `new` asserts that the length is
    // `base_type.size()`.
    slice: &'a [u8],
}
164 |
/// A criterion for comparing the values stored in one slot against those
/// stored in another. The check itself is carried out by
/// [compare](Comparison::compare).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Comparison {
    /// Passes when the values are exactly equal.
    Equal,
    /// Passes when the floating-point values are approximately equal
    /// according to the supplied [Tolerance] settings. For integer types
    /// this behaves like [Equal](Comparison::Equal).
    FuzzyEqual,
    /// Passes when the values differ.
    NotEqual,
    /// Passes when all of the values in the first slot are less than the
    /// values in the second slot.
    Less,
    /// Passes when all of the values in the first slot are greater than
    /// or equal to the values in the second slot.
    GreaterEqual,
    /// Passes when all of the values in the first slot are greater than
    /// the values in the second slot.
    Greater,
    /// Passes when all of the values in the first slot are less than or
    /// equal to the values in the second slot.
    LessEqual,
}
191 |
/// Component type and dimensions of one [Type] variant. One instance per
/// variant is stored in `TYPE_INFOS`.
#[derive(Debug)]
struct TypeInfo {
    // The type of each individual component.
    base_type: BaseType,
    // Number of columns: 1 for scalars and vectors.
    columns: usize,
    // Number of rows: 1 for scalars, the component count for vectors.
    rows: usize,
}
198 |
// Component type and dimensions of every [Type], indexed by
// `ty as usize` (see `Type::base_type`, `Type::rows` and
// `Type::columns`). The entries must stay in exactly the same order as
// the variants of `Type`; the trailing comment on each line names the
// variant it describes.
static TYPE_INFOS: [TypeInfo; 62] = [
    // Scalars: one column, one row.
    TypeInfo { base_type: BaseType::Int, columns: 1, rows: 1 }, // Int
    TypeInfo { base_type: BaseType::UInt, columns: 1, rows: 1 }, // UInt
    TypeInfo { base_type: BaseType::Int8, columns: 1, rows: 1 }, // Int8
    TypeInfo { base_type: BaseType::UInt8, columns: 1, rows: 1 }, // UInt8
    TypeInfo { base_type: BaseType::Int16, columns: 1, rows: 1 }, // Int16
    TypeInfo { base_type: BaseType::UInt16, columns: 1, rows: 1 }, // UInt16
    TypeInfo { base_type: BaseType::Int64, columns: 1, rows: 1 }, // Int64
    TypeInfo { base_type: BaseType::UInt64, columns: 1, rows: 1 }, // UInt64
    TypeInfo { base_type: BaseType::Float16, columns: 1, rows: 1 }, // Float16
    TypeInfo { base_type: BaseType::Float, columns: 1, rows: 1 }, // Float
    TypeInfo { base_type: BaseType::Double, columns: 1, rows: 1 }, // Double
    // Vectors: one column, one row per component.
    TypeInfo { base_type: BaseType::Float16, columns: 1, rows: 2 }, // F16Vec2
    TypeInfo { base_type: BaseType::Float16, columns: 1, rows: 3 }, // F16Vec3
    TypeInfo { base_type: BaseType::Float16, columns: 1, rows: 4 }, // F16Vec4
    TypeInfo { base_type: BaseType::Float, columns: 1, rows: 2 }, // Vec2
    TypeInfo { base_type: BaseType::Float, columns: 1, rows: 3 }, // Vec3
    TypeInfo { base_type: BaseType::Float, columns: 1, rows: 4 }, // Vec4
    TypeInfo { base_type: BaseType::Double, columns: 1, rows: 2 }, // DVec2
    TypeInfo { base_type: BaseType::Double, columns: 1, rows: 3 }, // DVec3
    TypeInfo { base_type: BaseType::Double, columns: 1, rows: 4 }, // DVec4
    TypeInfo { base_type: BaseType::Int, columns: 1, rows: 2 }, // IVec2
    TypeInfo { base_type: BaseType::Int, columns: 1, rows: 3 }, // IVec3
    TypeInfo { base_type: BaseType::Int, columns: 1, rows: 4 }, // IVec4
    TypeInfo { base_type: BaseType::UInt, columns: 1, rows: 2 }, // UVec2
    TypeInfo { base_type: BaseType::UInt, columns: 1, rows: 3 }, // UVec3
    TypeInfo { base_type: BaseType::UInt, columns: 1, rows: 4 }, // UVec4
    TypeInfo { base_type: BaseType::Int8, columns: 1, rows: 2 }, // I8Vec2
    TypeInfo { base_type: BaseType::Int8, columns: 1, rows: 3 }, // I8Vec3
    TypeInfo { base_type: BaseType::Int8, columns: 1, rows: 4 }, // I8Vec4
    TypeInfo { base_type: BaseType::UInt8, columns: 1, rows: 2 }, // U8Vec2
    TypeInfo { base_type: BaseType::UInt8, columns: 1, rows: 3 }, // U8Vec3
    TypeInfo { base_type: BaseType::UInt8, columns: 1, rows: 4 }, // U8Vec4
    TypeInfo { base_type: BaseType::Int16, columns: 1, rows: 2 }, // I16Vec2
    TypeInfo { base_type: BaseType::Int16, columns: 1, rows: 3 }, // I16Vec3
    TypeInfo { base_type: BaseType::Int16, columns: 1, rows: 4 }, // I16Vec4
    TypeInfo { base_type: BaseType::UInt16, columns: 1, rows: 2 }, // U16Vec2
    TypeInfo { base_type: BaseType::UInt16, columns: 1, rows: 3 }, // U16Vec3
    TypeInfo { base_type: BaseType::UInt16, columns: 1, rows: 4 }, // U16Vec4
    TypeInfo { base_type: BaseType::Int64, columns: 1, rows: 2 }, // I64Vec2
    TypeInfo { base_type: BaseType::Int64, columns: 1, rows: 3 }, // I64Vec3
    TypeInfo { base_type: BaseType::Int64, columns: 1, rows: 4 }, // I64Vec4
    TypeInfo { base_type: BaseType::UInt64, columns: 1, rows: 2 }, // U64Vec2
    TypeInfo { base_type: BaseType::UInt64, columns: 1, rows: 3 }, // U64Vec3
    TypeInfo { base_type: BaseType::UInt64, columns: 1, rows: 4 }, // U64Vec4
    // Matrices: GLSL `matCxR` has C columns and R rows.
    TypeInfo { base_type: BaseType::Float, columns: 2, rows: 2 }, // Mat2
    TypeInfo { base_type: BaseType::Float, columns: 2, rows: 3 }, // Mat2x3
    TypeInfo { base_type: BaseType::Float, columns: 2, rows: 4 }, // Mat2x4
    TypeInfo { base_type: BaseType::Float, columns: 3, rows: 2 }, // Mat3x2
    TypeInfo { base_type: BaseType::Float, columns: 3, rows: 3 }, // Mat3
    TypeInfo { base_type: BaseType::Float, columns: 3, rows: 4 }, // Mat3x4
    TypeInfo { base_type: BaseType::Float, columns: 4, rows: 2 }, // Mat4x2
    TypeInfo { base_type: BaseType::Float, columns: 4, rows: 3 }, // Mat4x3
    TypeInfo { base_type: BaseType::Float, columns: 4, rows: 4 }, // Mat4
    TypeInfo { base_type: BaseType::Double, columns: 2, rows: 2 }, // DMat2
    TypeInfo { base_type: BaseType::Double, columns: 2, rows: 3 }, // DMat2x3
    TypeInfo { base_type: BaseType::Double, columns: 2, rows: 4 }, // DMat2x4
    TypeInfo { base_type: BaseType::Double, columns: 3, rows: 2 }, // DMat3x2
    TypeInfo { base_type: BaseType::Double, columns: 3, rows: 3 }, // DMat3
    TypeInfo { base_type: BaseType::Double, columns: 3, rows: 4 }, // DMat3x4
    TypeInfo { base_type: BaseType::Double, columns: 4, rows: 2 }, // DMat4x2
    TypeInfo { base_type: BaseType::Double, columns: 4, rows: 3 }, // DMat4x3
    TypeInfo { base_type: BaseType::Double, columns: 4, rows: 4 }, // DMat4
];
263 |
// Mapping from GLSL type name to slot type. Sorted alphabetically (by
// byte value) so we can do a binary search in `Type::from_glsl_type`;
// any new entry must be inserted at its sorted position or lookups will
// silently fail. Square-matrix aliases such as `mat2x2` map to the same
// `Type` as their short form (`mat2`).
static GLSL_TYPE_NAMES: [(&'static str, Type); 68] = [
    ("dmat2", Type::DMat2),
    ("dmat2x2", Type::DMat2),
    ("dmat2x3", Type::DMat2x3),
    ("dmat2x4", Type::DMat2x4),
    ("dmat3", Type::DMat3),
    ("dmat3x2", Type::DMat3x2),
    ("dmat3x3", Type::DMat3),
    ("dmat3x4", Type::DMat3x4),
    ("dmat4", Type::DMat4),
    ("dmat4x2", Type::DMat4x2),
    ("dmat4x3", Type::DMat4x3),
    ("dmat4x4", Type::DMat4),
    ("double", Type::Double),
    ("dvec2", Type::DVec2),
    ("dvec3", Type::DVec3),
    ("dvec4", Type::DVec4),
    ("f16vec2", Type::F16Vec2),
    ("f16vec3", Type::F16Vec3),
    ("f16vec4", Type::F16Vec4),
    ("float", Type::Float),
    ("float16_t", Type::Float16),
    ("i16vec2", Type::I16Vec2),
    ("i16vec3", Type::I16Vec3),
    ("i16vec4", Type::I16Vec4),
    ("i64vec2", Type::I64Vec2),
    ("i64vec3", Type::I64Vec3),
    ("i64vec4", Type::I64Vec4),
    ("i8vec2", Type::I8Vec2),
    ("i8vec3", Type::I8Vec3),
    ("i8vec4", Type::I8Vec4),
    ("int", Type::Int),
    ("int16_t", Type::Int16),
    ("int64_t", Type::Int64),
    ("int8_t", Type::Int8),
    ("ivec2", Type::IVec2),
    ("ivec3", Type::IVec3),
    ("ivec4", Type::IVec4),
    ("mat2", Type::Mat2),
    ("mat2x2", Type::Mat2),
    ("mat2x3", Type::Mat2x3),
    ("mat2x4", Type::Mat2x4),
    ("mat3", Type::Mat3),
    ("mat3x2", Type::Mat3x2),
    ("mat3x3", Type::Mat3),
    ("mat3x4", Type::Mat3x4),
    ("mat4", Type::Mat4),
    ("mat4x2", Type::Mat4x2),
    ("mat4x3", Type::Mat4x3),
    ("mat4x4", Type::Mat4),
    ("u16vec2", Type::U16Vec2),
    ("u16vec3", Type::U16Vec3),
    ("u16vec4", Type::U16Vec4),
    ("u64vec2", Type::U64Vec2),
    ("u64vec3", Type::U64Vec3),
    ("u64vec4", Type::U64Vec4),
    ("u8vec2", Type::U8Vec2),
    ("u8vec3", Type::U8Vec3),
    ("u8vec4", Type::U8Vec4),
    ("uint", Type::UInt),
    ("uint16_t", Type::UInt16),
    ("uint64_t", Type::UInt64),
    ("uint8_t", Type::UInt8),
    ("uvec2", Type::UVec2),
    ("uvec3", Type::UVec3),
    ("uvec4", Type::UVec4),
    ("vec2", Type::Vec2),
    ("vec3", Type::Vec3),
    ("vec4", Type::Vec4),
];
336 |
// Mapping from operator name to Comparison. Sorted by ASCII values so
// we can do a binary search: '!' < '<' < '=' < '>' < '~', and a
// one-character operator sorts before the two-character operators that
// start with it. New entries must keep this order.
static COMPARISON_NAMES: [(&'static str, Comparison); 7] = [
    ("!=", Comparison::NotEqual),
    ("<", Comparison::Less),
    ("<=", Comparison::LessEqual),
    ("==", Comparison::Equal),
    (">", Comparison::Greater),
    (">=", Comparison::GreaterEqual),
    ("~=", Comparison::FuzzyEqual),
];
348 |
349 | impl BaseType {
350 | /// Returns the size in bytes of a variable of this base type.
351 | pub fn size(self) -> usize {
352 | match self {
353 | BaseType::Int => mem::size_of::<i32>(),
354 | BaseType::UInt => mem::size_of::<u32>(),
355 | BaseType::Int8 => mem::size_of::<i8>(),
356 | BaseType::UInt8 => mem::size_of::<u8>(),
357 | BaseType::Int16 => mem::size_of::<i16>(),
358 | BaseType::UInt16 => mem::size_of::<u16>(),
359 | BaseType::Int64 => mem::size_of::<i64>(),
360 | BaseType::UInt64 => mem::size_of::<u64>(),
361 | BaseType::Float16 => mem::size_of::<u16>(),
362 | BaseType::Float => mem::size_of::<u32>(),
363 | BaseType::Double => mem::size_of::<u64>(),
364 | }
365 | }
366 | }
367 |
368 | impl<'a> BaseTypeInSlice<'a> {
369 | /// Combines the given base type and slice into a
370 | /// `BaseTypeInSlice` so that it can be displayed with the
371 | /// [Display](std::fmt::Display) trait. The slice is expected to
372 | /// contain a representation of the base type in native-endian
373 | /// order. The slice needs to have the same size as
374 | /// `base_type.size()` or the constructor will panic.
375 | pub fn new(base_type: BaseType, slice: &'a [u8]) -> BaseTypeInSlice<'a> {
376 | assert_eq!(slice.len(), base_type.size());
377 |
378 | BaseTypeInSlice { base_type, slice }
379 | }
380 | }
381 |
382 | impl<'a> fmt::Display for BaseTypeInSlice<'a> {
383 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
384 | match self.base_type {
385 | BaseType::Int => {
386 | let v = i32::from_ne_bytes(self.slice.try_into().unwrap());
387 | write!(f, "{}", v)
388 | },
389 | BaseType::UInt => {
390 | let v = u32::from_ne_bytes(self.slice.try_into().unwrap());
391 | write!(f, "{}", v)
392 | },
393 | BaseType::Int8 => {
394 | let v = i8::from_ne_bytes(self.slice.try_into().unwrap());
395 | write!(f, "{}", v)
396 | },
397 | BaseType::UInt8 => {
398 | let v = u8::from_ne_bytes(self.slice.try_into().unwrap());
399 | write!(f, "{}", v)
400 | },
401 | BaseType::Int16 => {
402 | let v = i16::from_ne_bytes(self.slice.try_into().unwrap());
403 | write!(f, "{}", v)
404 | },
405 | BaseType::UInt16 => {
406 | let v = u16::from_ne_bytes(self.slice.try_into().unwrap());
407 | write!(f, "{}", v)
408 | },
409 | BaseType::Int64 => {
410 | let v = i64::from_ne_bytes(self.slice.try_into().unwrap());
411 | write!(f, "{}", v)
412 | },
413 | BaseType::UInt64 => {
414 | let v = u64::from_ne_bytes(self.slice.try_into().unwrap());
415 | write!(f, "{}", v)
416 | },
417 | BaseType::Float16 => {
418 | let v = u16::from_ne_bytes(self.slice.try_into().unwrap());
419 | let v = half_float::to_f64(v);
420 | write!(f, "{}", v)
421 | },
422 | BaseType::Float => {
423 | let v = f32::from_ne_bytes(self.slice.try_into().unwrap());
424 | write!(f, "{}", v)
425 | },
426 | BaseType::Double => {
427 | let v = f64::from_ne_bytes(self.slice.try_into().unwrap());
428 | write!(f, "{}", v)
429 | },
430 | }
431 | }
432 | }
433 |
434 | impl Type {
435 | /// Returns the type of the components of this slot type.
436 | #[inline]
437 | pub fn base_type(self) -> BaseType {
438 | TYPE_INFOS[self as usize].base_type
439 | }
440 |
441 | /// Returns the number of rows in this slot type. For primitive types
442 | /// this will be 1, for vectors it will be the size of the vector
443 | /// and for matrix types it will be the number of rows.
444 | #[inline]
445 | pub fn rows(self) -> usize {
446 | TYPE_INFOS[self as usize].rows
447 | }
448 |
449 | /// Return the number of columns in this slot type. For non-matrix
450 | /// types this will be 1.
451 | #[inline]
452 | pub fn columns(self) -> usize {
453 | TYPE_INFOS[self as usize].columns
454 | }
455 |
456 | /// Returns whether the type is a matrix type
457 | #[inline]
458 | pub fn is_matrix(self) -> bool {
459 | self.columns() > 1
460 | }
461 |
462 | // Returns the effective major axis setting that should be used
463 | // when calculating offsets. The setting should only affect matrix
464 | // types.
465 | fn effective_major_axis(self, layout: Layout) -> MajorAxis {
466 | if self.is_matrix() {
467 | layout.major
468 | } else {
469 | MajorAxis::Column
470 | }
471 | }
472 |
473 | // Return a tuple containing the size of the major and minor axes
474 | // (in that order) given the type and layout
475 | fn major_minor(self, layout: Layout) -> (usize, usize) {
476 | match self.effective_major_axis(layout) {
477 | MajorAxis::Column => (self.columns(), self.rows()),
478 | MajorAxis::Row => (self.rows(), self.columns()),
479 | }
480 | }
481 |
482 | /// Returns the matrix stride of the slot type. This is the offset
483 | /// in bytes between consecutive components of the minor axis. For
484 | /// example, in a column-major layout, the matrix stride of a vec4
485 | /// will be 16, because to move to the next column you need to add
486 | /// an offset of 16 bytes to the address of the previous column.
487 | pub fn matrix_stride(self, layout: Layout) -> usize {
488 | let component_size = self.base_type().size();
489 | let (_major, minor) = self.major_minor(layout);
490 |
491 | let base_stride = if minor == 3 {
492 | component_size * 4
493 | } else {
494 | component_size * minor
495 | };
496 |
497 | match layout.std {
498 | LayoutStd::Std140 => {
499 | // According to std140 the size is rounded up to a vec4
500 | util::align(base_stride, 16)
501 | },
502 | LayoutStd::Std430 => base_stride,
503 | }
504 | }
505 |
506 | /// Returns the offset between members of the elements of an array
507 | /// of this type of slot. There may be padding between the members
508 | /// to fulfill the alignment criteria of the layout.
509 | pub fn array_stride(self, layout: Layout) -> usize {
510 | let matrix_stride = self.matrix_stride(layout);
511 | let (major, _minor) = self.major_minor(layout);
512 |
513 | matrix_stride * major
514 | }
515 |
516 | /// Returns the size of the slot type. Note that unlike
517 | /// [array_stride](Type::array_stride), this won’t include the
518 | /// padding to align the values in an array according to the GLSL
519 | /// layout rules. It will however always have a size that is a
520 | /// multiple of the size of the base type so it is suitable as an
521 | /// array stride for packed values stored internally that aren’t
522 | /// passed to Vulkan.
523 | pub fn size(self, layout: Layout) -> usize {
524 | let matrix_stride = self.matrix_stride(layout);
525 | let base_size = self.base_type().size();
526 | let (major, minor) = self.major_minor(layout);
527 |
528 | (major - 1) * matrix_stride + base_size * minor
529 | }
530 |
531 | /// Returns an iterator that iterates over the offsets of the
532 | /// components of this slot. This can be used to extract the
533 | /// values out of an array of bytes containing the slot. The
534 | /// offsets are returned in column-major order (ie, all of the
535 | /// offsets for a row are returned before moving to the next
536 | /// column), but the offsets are calculated to take into account
537 | /// the major axis setting in the layout.
538 | pub fn offsets(self, layout: Layout) -> OffsetsIter {
539 | OffsetsIter::new(self, layout)
540 | }
541 |
542 | /// Returns a [Type] given the GLSL type name, for example `dmat2`
543 | /// or `i64vec3`. `None` will be returned if the type name isn’t
544 | /// recognised.
545 | pub fn from_glsl_type(type_name: &str) -> Option<Type> {
546 | match GLSL_TYPE_NAMES.binary_search_by(
547 | |&(name, _)| name.cmp(type_name)
548 | ) {
549 | Ok(pos) => Some(GLSL_TYPE_NAMES[pos].1),
550 | Err(_) => None,
551 | }
552 | }
553 | }
554 |
/// Iterator over the offsets into an array to extract the components
/// of a slot. Created by [Type::offsets].
///
/// Offsets are yielded column by column: all rows of a column come
/// before the next column.
#[derive(Debug, Clone)]
pub struct OffsetsIter {
    /// Slot type being walked; supplies the row and column counts.
    slot_type: Type,
    /// Row index of the next component within the current column.
    row: usize,
    /// Column currently being visited. Once it reaches the number of
    /// columns the iterator is exhausted.
    column: usize,
    /// Byte offset that the next call to `next` will return.
    offset: usize,
    /// Signed delta added to `offset` when wrapping from the last row
    /// of a column to the next column. It is applied after `row_offset`
    /// has already been added for that step.
    column_offset: isize,
    /// Signed delta added to `offset` after every returned component,
    /// to step to the next row.
    row_offset: isize,
}
566 |
567 | impl OffsetsIter {
568 | fn new(slot_type: Type, layout: Layout) -> OffsetsIter {
569 | let column_offset;
570 | let row_offset;
571 | let stride = slot_type.matrix_stride(layout);
572 | let base_size = slot_type.base_type().size();
573 |
574 | match slot_type.effective_major_axis(layout) {
575 | MajorAxis::Column => {
576 | column_offset =
577 | (stride - base_size * slot_type.rows()) as isize;
578 | row_offset = base_size as isize;
579 | },
580 | MajorAxis::Row => {
581 | column_offset =
582 | base_size as isize - (stride * slot_type.rows()) as isize;
583 | row_offset = stride as isize;
584 | },
585 | }
586 |
587 | OffsetsIter {
588 | slot_type,
589 | row: 0,
590 | column: 0,
591 | offset: 0,
592 | column_offset,
593 | row_offset,
594 | }
595 | }
596 | }
597 |
598 | impl Iterator for OffsetsIter {
599 | type Item = usize;
600 |
601 | fn next(&mut self) -> Option<usize> {
602 | if self.column >= self.slot_type.columns() {
603 | return None;
604 | }
605 |
606 | let result = self.offset;
607 |
608 | self.offset = ((self.offset as isize) + self.row_offset) as usize;
609 | self.row += 1;
610 |
611 | if self.row >= self.slot_type.rows() {
612 | self.row = 0;
613 | self.column += 1;
614 | self.offset =
615 | ((self.offset as isize) + self.column_offset) as usize;
616 | }
617 |
618 | Some(result)
619 | }
620 |
621 | fn size_hint(&self) -> (usize, Option<usize>) {
622 | let n_columns = self.slot_type.columns();
623 |
624 | let size = if self.column >= n_columns {
625 | 0
626 | } else {
627 | (self.slot_type.rows() - self.row)
628 | + ((n_columns - 1 - self.column) * self.slot_type.rows())
629 | };
630 |
631 | (size, Some(size))
632 | }
633 |
634 | fn count(self) -> usize {
635 | self.size_hint().0
636 | }
637 | }
638 |
639 | impl Comparison {
640 | fn compare_without_fuzzy<T: PartialOrd>(self, a: T, b: T) -> bool {
641 | match self {
642 | Comparison::Equal | Comparison::FuzzyEqual => a.eq(&b),
643 | Comparison::NotEqual => a.ne(&b),
644 | Comparison::Less => a.lt(&b),
645 | Comparison::GreaterEqual => a.ge(&b),
646 | Comparison::Greater => a.gt(&b),
647 | Comparison::LessEqual => a.le(&b),
648 | }
649 | }
650 |
651 | fn compare_value(
652 | self,
653 | tolerance: &Tolerance,
654 | component: usize,
655 | base_type: BaseType,
656 | a: &[u8],
657 | b: &[u8]
658 | ) -> bool {
659 | match base_type {
660 | BaseType::Int => {
661 | let a = i32::from_ne_bytes(a.try_into().unwrap());
662 | let b = i32::from_ne_bytes(b.try_into().unwrap());
663 | self.compare_without_fuzzy(a, b)
664 | },
665 | BaseType::UInt => {
666 | let a = u32::from_ne_bytes(a.try_into().unwrap());
667 | let b = u32::from_ne_bytes(b.try_into().unwrap());
668 | self.compare_without_fuzzy(a, b)
669 | },
670 | BaseType::Int8 => {
671 | let a = i8::from_ne_bytes(a.try_into().unwrap());
672 | let b = i8::from_ne_bytes(b.try_into().unwrap());
673 | self.compare_without_fuzzy(a, b)
674 | },
675 | BaseType::UInt8 => {
676 | let a = u8::from_ne_bytes(a.try_into().unwrap());
677 | let b = u8::from_ne_bytes(b.try_into().unwrap());
678 | self.compare_without_fuzzy(a, b)
679 | },
680 | BaseType::Int16 => {
681 | let a = i16::from_ne_bytes(a.try_into().unwrap());
682 | let b = i16::from_ne_bytes(b.try_into().unwrap());
683 | self.compare_without_fuzzy(a, b)
684 | },
685 | BaseType::UInt16 => {
686 | let a = u16::from_ne_bytes(a.try_into().unwrap());
687 | let b = u16::from_ne_bytes(b.try_into().unwrap());
688 | self.compare_without_fuzzy(a, b)
689 | },
690 | BaseType::Int64 => {
691 | let a = i64::from_ne_bytes(a.try_into().unwrap());
692 | let b = i64::from_ne_bytes(b.try_into().unwrap());
693 | self.compare_without_fuzzy(a, b)
694 | },
695 | BaseType::UInt64 => {
696 | let a = u64::from_ne_bytes(a.try_into().unwrap());
697 | let b = u64::from_ne_bytes(b.try_into().unwrap());
698 | self.compare_without_fuzzy(a, b)
699 | },
700 | BaseType::Float16 => {
701 | let a = u16::from_ne_bytes(a.try_into().unwrap());
702 | let a = half_float::to_f64(a);
703 | let b = u16::from_ne_bytes(b.try_into().unwrap());
704 | let b = half_float::to_f64(b);
705 |
706 | match self {
707 | Comparison::FuzzyEqual => tolerance.equal(component, a, b),
708 | Comparison::Equal
709 | | Comparison::NotEqual
710 | | Comparison::Less
711 | | Comparison::GreaterEqual
712 | | Comparison::Greater
713 | | Comparison::LessEqual => {
714 | self.compare_without_fuzzy(a, b)
715 | },
716 | }
717 | },
718 | BaseType::Float => {
719 | let a = f32::from_ne_bytes(a.try_into().unwrap());
720 | let b = f32::from_ne_bytes(b.try_into().unwrap());
721 |
722 | match self {
723 | Comparison::FuzzyEqual => {
724 | tolerance.equal(component, a as f64, b as f64)
725 | },
726 | Comparison::Equal
727 | | Comparison::NotEqual
728 | | Comparison::Less
729 | | Comparison::GreaterEqual
730 | | Comparison::Greater
731 | | Comparison::LessEqual => {
732 | self.compare_without_fuzzy(a, b)
733 | },
734 | }
735 | },
736 | BaseType::Double => {
737 | let a = f64::from_ne_bytes(a.try_into().unwrap());
738 | let b = f64::from_ne_bytes(b.try_into().unwrap());
739 |
740 | match self {
741 | Comparison::FuzzyEqual => {
742 | tolerance.equal(component, a, b)
743 | },
744 | Comparison::Equal
745 | | Comparison::NotEqual
746 | | Comparison::Less
747 | | Comparison::GreaterEqual
748 | | Comparison::Greater
749 | | Comparison::LessEqual => {
750 | self.compare_without_fuzzy(a, b)
751 | },
752 | }
753 | },
754 | }
755 | }
756 |
757 | /// Compares the values in `a` with the values in `b` using the
758 | /// chosen criteria. The values are extracted from the two byte
759 | /// slices using the layout and type provided. The tolerance
760 | /// parameter is only used if the comparison is
761 | /// [FuzzyEqual](Comparison::FuzzyEqual).
762 | pub fn compare(
763 | self,
764 | tolerance: &Tolerance,
765 | slot_type: Type,
766 | layout: Layout,
767 | a: &[u8],
768 | b: &[u8],
769 | ) -> bool {
770 | let base_type = slot_type.base_type();
771 | let base_type_size = base_type.size();
772 | let rows = slot_type.rows();
773 |
774 | for (i, offset) in slot_type.offsets(layout).enumerate() {
775 | let a = &a[offset..offset + base_type_size];
776 | let b = &b[offset..offset + base_type_size];
777 |
778 | if !self.compare_value(tolerance, i % rows, base_type, a, b) {
779 | return false;
780 | }
781 | }
782 |
783 | true
784 | }
785 |
786 | /// Get the comparison operator given the name in GLSL like `"<"`
787 | /// or `">="`. It also accepts `"~="` to return the fuzzy equal
788 | /// comparison. If the operator isn’t recognised then it will
789 | /// return `None`.
790 | pub fn from_operator(operator: &str) -> Option<Comparison> {
791 | match COMPARISON_NAMES.binary_search_by(
792 | |&(probe, _value)| probe.cmp(operator)
793 | ) {
794 | Ok(pos) => Some(COMPARISON_NAMES[pos].1),
795 | Err(_) => None,
796 | }
797 | }
798 | }
799 |
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_type_info() {
        // Test a few types

        // First
        assert!(matches!(Type::Int.base_type(), BaseType::Int));
        assert_eq!(Type::Int.columns(), 1);
        assert_eq!(Type::Int.rows(), 1);
        assert!(!Type::Int.is_matrix());

        // Last
        assert!(matches!(Type::DMat4.base_type(), BaseType::Double));
        assert_eq!(Type::DMat4.columns(), 4);
        assert_eq!(Type::DMat4.rows(), 4);
        assert!(Type::DMat4.is_matrix());

        // Vector
        assert!(matches!(Type::Vec4.base_type(), BaseType::Float));
        assert_eq!(Type::Vec4.columns(), 1);
        assert_eq!(Type::Vec4.rows(), 4);
        assert!(!Type::Vec4.is_matrix());

        // Non-square matrix
        assert!(matches!(Type::DMat4x3.base_type(), BaseType::Double));
        assert_eq!(Type::DMat4x3.columns(), 4);
        assert_eq!(Type::DMat4x3.rows(), 3);
        assert!(Type::DMat4x3.is_matrix());
    }

    #[test]
    fn test_base_type_size() {
        assert_eq!(BaseType::Int.size(), 4);
        assert_eq!(BaseType::UInt.size(), 4);
        assert_eq!(BaseType::Int8.size(), 1);
        assert_eq!(BaseType::UInt8.size(), 1);
        assert_eq!(BaseType::Int16.size(), 2);
        assert_eq!(BaseType::UInt16.size(), 2);
        assert_eq!(BaseType::Int64.size(), 8);
        assert_eq!(BaseType::UInt64.size(), 8);
        assert_eq!(BaseType::Float16.size(), 2);
        assert_eq!(BaseType::Float.size(), 4);
        assert_eq!(BaseType::Double.size(), 8);
    }

    #[test]
    fn test_matrix_stride() {
        let std140_column =
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column };
        let std430_column =
            Layout { std: LayoutStd::Std430, major: MajorAxis::Column };
        let std430_row = Layout { std: LayoutStd::Std430, major: MajorAxis::Row };

        assert_eq!(Type::Mat4.matrix_stride(std430_column), 16);
        // In std140 the vectors along the minor axis are aligned to the
        // size of a vec4
        assert_eq!(Type::Mat4x2.matrix_stride(std140_column), 16);
        // In std430 they are not
        assert_eq!(Type::Mat4x2.matrix_stride(std430_column), 8);
        // 3-component axes are always aligned to a vec4 regardless of
        // the layout standard
        assert_eq!(Type::Mat4x3.matrix_stride(std430_column), 16);

        // With a row-major layout the roles of the axes are swapped, so
        // the stride depends on the number of columns instead of rows
        assert_eq!(Type::Mat4x2.matrix_stride(std430_row), 16);
        assert_eq!(Type::Mat2x4.matrix_stride(std430_row), 8);

        // For non-matrix types the matrix stride is the array stride
        assert_eq!(Type::Vec3.matrix_stride(std430_column), 16);
        assert_eq!(Type::Float.matrix_stride(std430_column), 4);
        // Row-major layout does not affect vectors
        assert_eq!(Type::Vec2.matrix_stride(std430_row), 8);
    }

    #[test]
    fn test_array_stride() {
        let layout = Layout { std: LayoutStd::Std430, major: MajorAxis::Column };

        assert_eq!(Type::UInt8.array_stride(layout), 1);
        assert_eq!(Type::I8Vec2.array_stride(layout), 2);
        assert_eq!(Type::I8Vec3.array_stride(layout), 4);
        assert_eq!(Type::Mat4x3.array_stride(layout), 64);
    }

    #[test]
    fn test_size() {
        let layout = Layout { std: LayoutStd::Std430, major: MajorAxis::Column };

        assert_eq!(Type::UInt8.size(layout), 1);
        assert_eq!(Type::I8Vec2.size(layout), 2);
        assert_eq!(Type::I8Vec3.size(layout), 3);
        assert_eq!(Type::Mat4x3.size(layout), 3 * 4 * 4 + 3 * 4);
    }

    #[test]
    fn test_iterator() {
        let default_layout = Layout {
            std: LayoutStd::Std140,
            major: MajorAxis::Column,
        };

        // The components of a vec4 are packed tightly
        let vec4_offsets =
            Type::Vec4.offsets(default_layout).collect::<Vec<usize>>();
        assert_eq!(vec4_offsets, vec![0, 4, 8, 12]);

        // The major axis setting shouldn't affect vectors
        let row_major_layout = Layout {
            std: LayoutStd::Std140,
            major: MajorAxis::Row,
        };
        let vec4_offsets =
            Type::Vec4.offsets(row_major_layout).collect::<Vec<usize>>();
        assert_eq!(vec4_offsets, vec![0, 4, 8, 12]);

        // Components of a mat4 are packed tightly
        let mat4_offsets =
            Type::Mat4.offsets(default_layout).collect::<Vec<usize>>();
        assert_eq!(
            mat4_offsets,
            vec![0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60],
        );

        // If there are 3 rows then they are padded out to align
        // each column to the size of a vec4
        let mat4x3_offsets =
            Type::Mat4x3.offsets(default_layout).collect::<Vec<usize>>();
        assert_eq!(
            mat4x3_offsets,
            vec![0, 4, 8, 16, 20, 24, 32, 36, 40, 48, 52, 56],
        );

        // If row-major order is being used then the offsets are still
        // reported in column-major order but the addresses that they
        // point to are in row-major order.
        // [0 4 8]
        // [16 20 24]
        // [32 36 40]
        // [48 52 56]
        let mat3x4_offsets =
            Type::Mat3x4.offsets(row_major_layout).collect::<Vec<usize>>();
        assert_eq!(
            mat3x4_offsets,
            vec![0, 16, 32, 48, 4, 20, 36, 52, 8, 24, 40, 56],
        );

        // Check the size_hint and count. The size hint is always
        // exact and should match the count.
        let mut iter = Type::Mat4.offsets(default_layout);

        for i in 0..16 {
            let expected_count = 16 - i;
            assert_eq!(
                iter.size_hint(),
                (expected_count, Some(expected_count))
            );
            assert_eq!(iter.clone().count(), expected_count);

            assert!(iter.next().is_some());
        }

        assert_eq!(iter.size_hint(), (0, Some(0)));
        assert_eq!(iter.clone().count(), 0);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_compare() {
        let layout = Layout { std: LayoutStd::Std140, major: MajorAxis::Column };

        // Test every comparison mode on a single unsigned byte. Each
        // entry is (comparison, a, b, expected result of `a op b`).
        let cases = [
            (Comparison::Equal, 5u8, 5u8, true),
            (Comparison::Equal, 6, 5, false),
            (Comparison::NotEqual, 6, 5, true),
            (Comparison::NotEqual, 5, 5, false),
            (Comparison::Less, 4, 5, true),
            (Comparison::Less, 5, 5, false),
            (Comparison::Less, 6, 5, false),
            (Comparison::LessEqual, 4, 5, true),
            (Comparison::LessEqual, 5, 5, true),
            (Comparison::LessEqual, 6, 5, false),
            (Comparison::Greater, 5, 4, true),
            (Comparison::Greater, 5, 5, false),
            (Comparison::Greater, 5, 6, false),
            (Comparison::GreaterEqual, 5, 4, true),
            (Comparison::GreaterEqual, 5, 5, true),
            (Comparison::GreaterEqual, 5, 6, false),
        ];

        for &(comparison, a, b, expected) in cases.iter() {
            assert_eq!(
                comparison.compare(
                    &Default::default(), // tolerance
                    Type::UInt8,
                    layout,
                    &[a],
                    &[b],
                ),
                expected,
                "{:?} with a = {}, b = {}",
                comparison,
                a,
                b,
            );
        }

        assert!(Comparison::FuzzyEqual.compare(
            &Tolerance::new([1.0; 4], false),
            Type::Float,
            layout,
            &5.0f32.to_ne_bytes(),
            &5.9f32.to_ne_bytes(),
        ));
        assert!(!Comparison::FuzzyEqual.compare(
            &Tolerance::new([1.0; 4], false),
            Type::Float,
            layout,
            &5.0f32.to_ne_bytes(),
            &6.1f32.to_ne_bytes(),
        ));

        // Test all types

        macro_rules! test_compare_type_with_values {
            ($type_enum:expr, $low_value:expr, $high_value:expr) => {
                assert!(Comparison::Less.compare(
                    &Default::default(),
                    $type_enum,
                    Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
                    &$low_value.to_ne_bytes(),
                    &$high_value.to_ne_bytes(),
                ));
            };
        }

        macro_rules! test_compare_simple_type {
            ($type_num:expr, $rust_type:ty) => {
                test_compare_type_with_values!(
                    $type_num,
                    0 as $rust_type,
                    <$rust_type>::MAX
                );
            };
        }

        test_compare_simple_type!(Type::Int, i32);
        test_compare_simple_type!(Type::UInt, u32);
        test_compare_simple_type!(Type::Int8, i8);
        test_compare_simple_type!(Type::UInt8, u8);
        test_compare_simple_type!(Type::Int16, i16);
        test_compare_simple_type!(Type::UInt16, u16);
        test_compare_simple_type!(Type::Int64, i64);
        test_compare_simple_type!(Type::UInt64, u64);
        test_compare_simple_type!(Type::Float, f32);
        test_compare_simple_type!(Type::Double, f64);
        test_compare_type_with_values!(
            Type::Float16,
            half_float::from_f32(-100.0),
            half_float::from_f32(300.0)
        );

        macro_rules! test_compare_big_type {
            ($type_enum:expr, $bytes:expr) => {
                let layout = Layout {
                    std: LayoutStd::Std140,
                    major: MajorAxis::Column
                };
                let mut bytes = Vec::new();
                assert_eq!($type_enum.size(layout) % $bytes.len(), 0);
                for _ in 0..($type_enum.size(layout) / $bytes.len()) {
                    bytes.extend_from_slice($bytes);
                }
                assert!(Comparison::FuzzyEqual.compare(
                    &Default::default(),
                    $type_enum,
                    layout,
                    &bytes,
                    &bytes,
                ));
            };
        }

        test_compare_big_type!(
            Type::F16Vec2,
            &half_float::from_f32(1.0).to_ne_bytes()
        );
        test_compare_big_type!(
            Type::F16Vec3,
            &half_float::from_f32(1.0).to_ne_bytes()
        );
        test_compare_big_type!(
            Type::F16Vec4,
            &half_float::from_f32(1.0).to_ne_bytes()
        );

        test_compare_big_type!(Type::Vec2, &42.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Vec3, &24.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Vec4, &6.0f32.to_ne_bytes());
        test_compare_big_type!(Type::DVec2, &42.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DVec3, &24.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DVec4, &6.0f64.to_ne_bytes());
        test_compare_big_type!(Type::IVec2, &(-42i32).to_ne_bytes());
        test_compare_big_type!(Type::IVec3, &(-24i32).to_ne_bytes());
        test_compare_big_type!(Type::IVec4, &(-6i32).to_ne_bytes());
        test_compare_big_type!(Type::UVec2, &42u32.to_ne_bytes());
        test_compare_big_type!(Type::UVec3, &24u32.to_ne_bytes());
        test_compare_big_type!(Type::UVec4, &6u32.to_ne_bytes());
        test_compare_big_type!(Type::I8Vec2, &(-42i8).to_ne_bytes());
        test_compare_big_type!(Type::I8Vec3, &(-24i8).to_ne_bytes());
        test_compare_big_type!(Type::I8Vec4, &(-6i8).to_ne_bytes());
        test_compare_big_type!(Type::U8Vec2, &42u8.to_ne_bytes());
        test_compare_big_type!(Type::U8Vec3, &24u8.to_ne_bytes());
        test_compare_big_type!(Type::U8Vec4, &6u8.to_ne_bytes());
        test_compare_big_type!(Type::I16Vec2, &(-42i16).to_ne_bytes());
        test_compare_big_type!(Type::I16Vec3, &(-24i16).to_ne_bytes());
        test_compare_big_type!(Type::I16Vec4, &(-6i16).to_ne_bytes());
        test_compare_big_type!(Type::U16Vec2, &42u16.to_ne_bytes());
        test_compare_big_type!(Type::U16Vec3, &24u16.to_ne_bytes());
        test_compare_big_type!(Type::U16Vec4, &6u16.to_ne_bytes());
        test_compare_big_type!(Type::I64Vec2, &(-42i64).to_ne_bytes());
        test_compare_big_type!(Type::I64Vec3, &(-24i64).to_ne_bytes());
        test_compare_big_type!(Type::I64Vec4, &(-6i64).to_ne_bytes());
        test_compare_big_type!(Type::U64Vec2, &42u64.to_ne_bytes());
        test_compare_big_type!(Type::U64Vec3, &24u64.to_ne_bytes());
        test_compare_big_type!(Type::U64Vec4, &6u64.to_ne_bytes());
        test_compare_big_type!(Type::Mat2, &42.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Mat2x3, &42.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Mat2x4, &42.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Mat3, &24.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Mat3x2, &24.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Mat3x4, &24.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Mat4, &6.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Mat4x2, &6.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Mat4x3, &6.0f32.to_ne_bytes());
        test_compare_big_type!(Type::DMat2, &42.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DMat2x3, &42.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DMat2x4, &42.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DMat3, &24.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DMat3x2, &24.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DMat3x4, &24.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DMat4, &6.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DMat4x2, &6.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DMat4x3, &6.0f64.to_ne_bytes());

        // Check that the comparison fails if a value other than the
        // first one differs
        assert!(!Comparison::Equal.compare(
            &Default::default(), // tolerance
            Type::U8Vec4,
            layout,
            &[5, 6, 7, 8],
            &[5, 6, 7, 9],
        ));
    }

    #[test]
    fn test_from_glsl_type() {
        assert_eq!(Type::from_glsl_type("vec3"), Some(Type::Vec3));
        assert_eq!(Type::from_glsl_type("dvec4"), Some(Type::DVec4));
        assert_eq!(Type::from_glsl_type("i64vec4"), Some(Type::I64Vec4));
        assert_eq!(Type::from_glsl_type("uint"), Some(Type::UInt));
        assert_eq!(Type::from_glsl_type("uint16_t"), Some(Type::UInt16));
        assert_eq!(Type::from_glsl_type("dvec5"), None);

        // Check that all types can be found with the binary search
        for &(type_name, slot_type) in GLSL_TYPE_NAMES.iter() {
            assert_eq!(Type::from_glsl_type(type_name), Some(slot_type));
        }
    }

    #[test]
    fn test_comparison_from_operator() {
        // If a comparison is added then please add it to this test too
        assert_eq!(COMPARISON_NAMES.len(), 7);
        assert_eq!(
            Comparison::from_operator("!="),
            Some(Comparison::NotEqual),
        );
        assert_eq!(
            Comparison::from_operator("<"),
            Some(Comparison::Less),
        );
        assert_eq!(
            Comparison::from_operator("<="),
            Some(Comparison::LessEqual),
        );
        assert_eq!(
            Comparison::from_operator("=="),
            Some(Comparison::Equal),
        );
        assert_eq!(
            Comparison::from_operator(">"),
            Some(Comparison::Greater),
        );
        assert_eq!(
            Comparison::from_operator(">="),
            Some(Comparison::GreaterEqual),
        );
        assert_eq!(
            Comparison::from_operator("~="),
            Some(Comparison::FuzzyEqual),
        );

        assert_eq!(Comparison::from_operator("<=>"), None);
    }

    #[track_caller]
    fn check_base_type_format(
        base_type: BaseType,
        bytes: &[u8],
        expected: &str,
    ) {
        assert_eq!(
            &BaseTypeInSlice::new(base_type, bytes).to_string(),
            expected
        );
    }

    #[test]
    fn base_type_format() {
        check_base_type_format(BaseType::Int, &1234i32.to_ne_bytes(), "1234");
        check_base_type_format(
            BaseType::UInt,
            &u32::MAX.to_ne_bytes(),
            "4294967295",
        );
        check_base_type_format(BaseType::Int8, &(-4i8).to_ne_bytes(), "-4");
        check_base_type_format(BaseType::UInt8, &129u8.to_ne_bytes(), "129");
        check_base_type_format(BaseType::Int16, &(-4i16).to_ne_bytes(), "-4");
        check_base_type_format(
            BaseType::UInt16,
            &32769u16.to_ne_bytes(),
            "32769"
        );
        check_base_type_format(BaseType::Int64, &(-4i64).to_ne_bytes(), "-4");
        check_base_type_format(
            BaseType::UInt64,
            &u64::MAX.to_ne_bytes(),
            "18446744073709551615"
        );
        check_base_type_format(
            BaseType::Float16,
            &0xc000u16.to_ne_bytes(),
            "-2"
        );
        check_base_type_format(
            BaseType::Float,
            &2.0f32.to_ne_bytes(),
            "2"
        );
        check_base_type_format(
            BaseType::Double,
            &2.0f64.to_ne_bytes(),
            "2"
        );
    }
}
1410 |
```
--------------------------------------------------------------------------------
/src/main.rs:
--------------------------------------------------------------------------------
```rust
1 | use anyhow::Result;
2 | use clap::Parser;
3 | use image::codecs::pnm::PnmDecoder;
4 | use image::{DynamicImage, ImageError, RgbImage};
5 | use rmcp::{
6 | Error as McpError, RoleServer, ServerHandler, ServiceExt, const_string, model::*, schemars,
7 | service::RequestContext, tool, transport::stdio,
8 | };
9 | use serde_json::json;
10 | use shaderc::{self, CompileOptions, Compiler, OptimizationLevel, ShaderKind};
11 | use std::fs::File;
12 | use std::io::BufReader;
13 | use std::path::{Path, PathBuf};
14 | use std::sync::Arc;
15 | use tokio::sync::Mutex;
16 | use tracing_subscriber::{self, EnvFilter};
17 |
18 | pub fn read_and_decode_ppm_file<P: AsRef<Path>>(path: P) -> Result<RgbImage, ImageError> {
19 | let file = File::open(path)?;
20 | let reader = BufReader::new(file);
21 | let decoder: PnmDecoder<BufReader<File>> = PnmDecoder::new(reader)?;
22 | let dynamic_image = DynamicImage::from_decoder(decoder)?;
23 | let rgb_image = dynamic_image.into_rgb8();
24 | Ok(rgb_image)
25 | }
26 |
// Shader stage selector. It derives `JsonSchema`, presumably so that it
// can appear in an MCP tool's parameter schema; confirm against the
// tool definitions.
//
// NOTE: plain `//` comments are used on purpose. schemars turns `///`
// doc comments into JSON-schema `description` fields, so adding them
// here would change the schema advertised to clients. The per-variant
// descriptions below are runtime text for the same reason.
//
// There are no `#[serde(rename...)]` attributes, so each variant
// (de)serializes as its Rust name, e.g. "Vert" or "Comp".
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub enum ShaderStage {
    #[schemars(description = "Vertex processing stage (transforms vertices)")]
    Vert,
    #[schemars(description = "Fragment processing stage (determines pixel colors)")]
    Frag,
    #[schemars(description = "Tessellation control stage for patch control points")]
    Tesc,
    #[schemars(description = "Tessellation evaluation stage for computing tessellated geometry")]
    Tese,
    #[schemars(description = "Geometry stage for generating/modifying primitives")]
    Geom,
    #[schemars(description = "Compute stage for general-purpose computation")]
    Comp,
}
// A device capability or format that a generated shader test requires.
// NOTE(review): presumably rendered into the `[require]` section of a
// VkRunner shader_test; confirm against the code that consumes it.
//
// NOTE: plain `//` comments are used on purpose. schemars turns `///`
// doc comments into JSON-schema `description` fields, so adding them
// would change the schema advertised to clients.
//
// Serde uses its default externally tagged representation here. Unit
// variants serialize as a bare string (e.g. "WideLines"), newtype
// variants as `{"DepthStencil": "..."}`, and the struct variant as
// `{"CooperativeMatrix": {"m": .., "n": .., "component_type": ..}}`.
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub enum ShaderRunnerRequire {
    #[schemars(
        description = "Enables cooperative matrix operations with specified dimensions and data type"
    )]
    CooperativeMatrix {
        // Matrix dimensions (the "specified dimensions" in the description).
        m: u32,
        n: u32,
        // Component data type, passed through as a free-form string; no
        // validation of the spelling happens here.
        component_type: String,
    },

    // Format name given as a string; presumably a Vulkan format name
    // (TODO confirm accepted spellings).
    #[schemars(
        description = "Specifies depth/stencil format for depth testing and stencil operations"
    )]
    DepthStencil(String),

    // Format name given as a string; same caveat as `DepthStencil`.
    #[schemars(description = "Specifies framebuffer format for render target output")]
    Framebuffer(String),

    #[schemars(description = "Enables double-precision floating point operations in shaders")]
    ShaderFloat64,

    #[schemars(description = "Enables geometry shader stage for primitive manipulation")]
    GeometryShader,

    #[schemars(description = "Enables lines with width greater than 1.0 pixel")]
    WideLines,

    #[schemars(description = "Enables bitwise logical operations on framebuffer contents")]
    LogicOp,

    // Required subgroup size as a plain integer.
    #[schemars(description = "Specifies required subgroup size for shader execution")]
    SubgroupSize(u32),

    #[schemars(description = "Enables memory stores and atomic operations in fragment shaders")]
    FragmentStoresAndAtomics,

    #[schemars(description = "Enables shaders to use raw buffer addresses")]
    BufferDeviceAddress,
}
// Shader sections of the generated `.shader_test` file. `compile_run_shaders` writes one section
// header per pass and, for the `*Spirv` variants, copies the referenced file's text verbatim
// below it. A path that does not start with `/tmp` is read from `/tmp/<path>`, which is also
// where `CompileRequest::tmp_output_path` output is written for such paths.
//
// Plain `//` comments are used on purpose: `///` doc comments would be picked up by `schemars`
// and alter the generated JSON schema. Variant order is kept as-is for the same reason.
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub enum ShaderRunnerPass {
    // Emitted as `[vertex shader passthrough]`; no file is read.
    #[schemars(
        description = "Use a built-in pass-through vertex shader (forwards attributes to fragment shader)"
    )]
    VertPassthrough,

    // Emitted as `[vertex shader spirv]` followed by the file contents.
    #[schemars(description = "Use compiled vertex shader from specified SPIR-V assembly file")]
    VertSpirv {
        #[schemars(
            description = "Path to the compiled vertex shader SPIR-V assembly (.spvasm) file"
        )]
        vert_spvasm_path: String,
    },

    // Emitted as `[fragment shader spirv]` followed by the file contents.
    #[schemars(description = "Use compiled fragment shader from specified SPIR-V assembly file")]
    FragSpirv {
        #[schemars(
            description = "Path to the compiled fragment shader SPIR-V assembly (.spvasm) file"
        )]
        frag_spvasm_path: String,
    },

    // Emitted as `[compute shader spirv]` followed by the file contents.
    #[schemars(description = "Use compiled compute shader from specified SPIR-V assembly file")]
    CompSpirv {
        #[schemars(
            description = "Path to the compiled compute shader SPIR-V assembly (.spvasm) file"
        )]
        comp_spvasm_path: String,
    },

    // Emitted as `[geometry shader spirv]` followed by the file contents.
    #[schemars(description = "Use compiled geometry shader from specified SPIR-V assembly file")]
    GeomSpirv {
        #[schemars(
            description = "Path to the compiled geometry shader SPIR-V assembly (.spvasm) file"
        )]
        geom_spvasm_path: String,
    },

    // Emitted as `[tessellation control shader spirv]` followed by the file contents.
    #[schemars(
        description = "Use compiled tessellation control shader from specified SPIR-V assembly file"
    )]
    TescSpirv {
        #[schemars(
            description = "Path to the compiled tessellation control shader SPIR-V assembly (.spvasm) file"
        )]
        tesc_spvasm_path: String,
    },

    // Emitted as `[tessellation evaluation shader spirv]` followed by the file contents.
    #[schemars(
        description = "Use compiled tessellation evaluation shader from specified SPIR-V assembly file"
    )]
    TeseSpirv {
        #[schemars(
            description = "Path to the compiled tessellation evaluation shader SPIR-V assembly (.spvasm) file"
        )]
        tese_spvasm_path: String,
    },
}
143 |
144 | #[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
145 | pub enum ShaderRunnerVertexData {
146 | #[schemars(description = "Defines an attribute format at a shader location/binding")]
147 | AttributeFormat {
148 | #[schemars(description = "Location/binding number in the shader")]
149 | location: u32,
150 |
151 | #[schemars(description = "Format name (e.g., R32G32_SFLOAT, A8B8G8R8_UNORM_PACK32)")]
152 | format: String,
153 | },
154 |
155 | #[schemars(description = "2D position data (for R32G32_SFLOAT format)")]
156 | Vec2 {
157 | #[schemars(description = "X coordinate value")]
158 | x: f32,
159 |
160 | #[schemars(description = "Y coordinate value")]
161 | y: f32,
162 | },
163 |
164 | #[schemars(description = "3D position data (for R32G32B32_SFLOAT format)")]
165 | Vec3 {
166 | #[schemars(description = "X coordinate value")]
167 | x: f32,
168 |
169 | #[schemars(description = "Y coordinate value")]
170 | y: f32,
171 |
172 | #[schemars(description = "Z coordinate value")]
173 | z: f32,
174 | },
175 |
176 | #[schemars(description = "4D position/color data (for R32G32B32A32_SFLOAT format)")]
177 | Vec4 {
178 | #[schemars(description = "X coordinate or Red component")]
179 | x: f32,
180 |
181 | #[schemars(description = "Y coordinate or Green component")]
182 | y: f32,
183 |
184 | #[schemars(description = "Z coordinate or Blue component")]
185 | z: f32,
186 |
187 | #[schemars(description = "W coordinate or Alpha component")]
188 | w: f32,
189 | },
190 |
191 | #[schemars(description = "RGB color data (for R8G8B8_UNORM format)")]
192 | RGB {
193 | #[schemars(description = "Red component (0-255)")]
194 | r: u8,
195 |
196 | #[schemars(description = "Green component (0-255)")]
197 | g: u8,
198 |
199 | #[schemars(description = "Blue component (0-255)")]
200 | b: u8,
201 | },
202 |
203 | #[schemars(description = "ARGB color as hex (for A8B8G8R8_UNORM_PACK32 format)")]
204 | Hex {
205 | #[schemars(description = "Color as hex string (0xAARRGGBB format)")]
206 | value: String,
207 | },
208 |
209 | #[schemars(description = "Generic data components for custom formats")]
210 | GenericComponents {
211 | #[schemars(description = "Component values as strings, interpreted by format")]
212 | components: Vec<String>,
213 | },
214 | }
215 |
216 | #[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
217 | pub enum ShaderRunnerTest {
218 | #[schemars(description = "Set fragment shader entrypoint function name")]
219 | FragmentEntrypoint {
220 | #[schemars(description = "Function name to use as entrypoint")]
221 | name: String,
222 | },
223 |
224 | #[schemars(description = "Set vertex shader entrypoint function name")]
225 | VertexEntrypoint {
226 | #[schemars(description = "Function name to use as entrypoint")]
227 | name: String,
228 | },
229 |
230 | #[schemars(description = "Set compute shader entrypoint function name")]
231 | ComputeEntrypoint {
232 | #[schemars(description = "Function name to use as entrypoint")]
233 | name: String,
234 | },
235 |
236 | #[schemars(description = "Set geometry shader entrypoint function name")]
237 | GeometryEntrypoint {
238 | #[schemars(description = "Function name to use as entrypoint")]
239 | name: String,
240 | },
241 |
242 | #[schemars(description = "Draw a rectangle using normalized device coordinates")]
243 | DrawRect {
244 | #[schemars(description = "X coordinate of lower-left corner")]
245 | x: f32,
246 |
247 | #[schemars(description = "Y coordinate of lower-left corner")]
248 | y: f32,
249 |
250 | #[schemars(description = "Width of the rectangle")]
251 | width: f32,
252 |
253 | #[schemars(description = "Height of the rectangle")]
254 | height: f32,
255 | },
256 |
257 | #[schemars(description = "Draw primitives using vertex data")]
258 | DrawArrays {
259 | #[schemars(description = "Primitive type (TRIANGLE_LIST, POINT_LIST, etc.)")]
260 | primitive_type: String,
261 |
262 | #[schemars(description = "Index of first vertex")]
263 | first: u32,
264 |
265 | #[schemars(description = "Number of vertices to draw")]
266 | count: u32,
267 | },
268 |
269 | #[schemars(description = "Draw primitives using indexed vertex data")]
270 | DrawArraysIndexed {
271 | #[schemars(description = "Primitive type (TRIANGLE_LIST, etc.)")]
272 | primitive_type: String,
273 |
274 | #[schemars(description = "Index of first index")]
275 | first: u32,
276 |
277 | #[schemars(description = "Number of indices to use")]
278 | count: u32,
279 | },
280 |
281 | #[schemars(description = "Create or initialize a Shader Storage Buffer Object (SSBO)")]
282 | SSBO {
283 | #[schemars(description = "Binding point in the shader")]
284 | binding: u32,
285 |
286 | #[schemars(description = "Size in bytes (if data not provided)")]
287 | size: Option<u32>,
288 |
289 | #[schemars(description = "Initial buffer contents")]
290 | data: Option<Vec<u8>>,
291 |
292 | #[schemars(description = "Descriptor set number (default: 0)")]
293 | descriptor_set: Option<u32>,
294 | },
295 |
296 | #[schemars(description = "Update a portion of an SSBO with new data")]
297 | SSBOSubData {
298 | #[schemars(description = "Binding point in the shader")]
299 | binding: u32,
300 |
301 | #[schemars(description = "Data type (float, vec4, etc.)")]
302 | data_type: String,
303 |
304 | #[schemars(description = "Byte offset into the buffer")]
305 | offset: u32,
306 |
307 | #[schemars(description = "Values to write (as strings)")]
308 | values: Vec<String>,
309 |
310 | #[schemars(description = "Descriptor set number (default: 0)")]
311 | descriptor_set: Option<u32>,
312 | },
313 |
314 | #[schemars(description = "Create or initialize a Uniform Buffer Object (UBO)")]
315 | UBO {
316 | #[schemars(description = "Binding point in the shader")]
317 | binding: u32,
318 |
319 | #[schemars(description = "Buffer contents")]
320 | data: Vec<u8>,
321 |
322 | #[schemars(description = "Descriptor set number (default: 0)")]
323 | descriptor_set: Option<u32>,
324 | },
325 |
326 | #[schemars(description = "Update a portion of a UBO with new data")]
327 | UBOSubData {
328 | #[schemars(description = "Binding point in the shader")]
329 | binding: u32,
330 |
331 | #[schemars(description = "Data type (float, vec4, etc.)")]
332 | data_type: String,
333 |
334 | #[schemars(description = "Byte offset into the buffer")]
335 | offset: u32,
336 |
337 | #[schemars(description = "Values to write (as strings)")]
338 | values: Vec<String>,
339 |
340 | #[schemars(description = "Descriptor set number (default: 0)")]
341 | descriptor_set: Option<u32>,
342 | },
343 |
344 | #[schemars(description = "Set memory layout for buffer data")]
345 | BufferLayout {
346 | #[schemars(description = "Buffer type (ubo, ssbo)")]
347 | buffer_type: String,
348 |
349 | #[schemars(description = "Layout specification (std140, std430, row_major, column_major)")]
350 | layout_type: String,
351 | },
352 |
353 | #[schemars(description = "Set push constant values")]
354 | Push {
355 | #[schemars(description = "Data type (float, vec4, etc.)")]
356 | data_type: String,
357 |
358 | #[schemars(description = "Byte offset into push constant block")]
359 | offset: u32,
360 |
361 | #[schemars(description = "Values to write (as strings)")]
362 | values: Vec<String>,
363 | },
364 |
365 | #[schemars(description = "Set memory layout for push constants")]
366 | PushLayout {
367 | #[schemars(description = "Layout specification (std140, std430)")]
368 | layout_type: String,
369 | },
370 |
371 | #[schemars(description = "Execute compute shader with specified workgroup counts")]
372 | Compute {
373 | #[schemars(description = "Number of workgroups in X dimension")]
374 | x: u32,
375 |
376 | #[schemars(description = "Number of workgroups in Y dimension")]
377 | y: u32,
378 |
379 | #[schemars(description = "Number of workgroups in Z dimension")]
380 | z: u32,
381 | },
382 |
383 | #[schemars(description = "Verify framebuffer or buffer contents match expected values")]
384 | Probe {
385 | #[schemars(description = "Probe type (all, rect, ssbo, etc.)")]
386 | probe_type: String,
387 |
388 | #[schemars(description = "Component format (rgba, rgb, etc.)")]
389 | format: String,
390 |
391 | #[schemars(description = "Parameters (coordinates, expected values)")]
392 | args: Vec<String>,
393 | },
394 |
395 | #[schemars(description = "Verify contents using normalized (0-1) coordinates")]
396 | RelativeProbe {
397 | #[schemars(description = "Probe type (rect, etc.)")]
398 | probe_type: String,
399 |
400 | #[schemars(description = "Component format (rgba, rgb, etc.)")]
401 | format: String,
402 |
403 | #[schemars(description = "Parameters (coordinates, expected values)")]
404 | args: Vec<String>,
405 | },
406 |
407 | #[schemars(description = "Set acceptable error margin for value comparisons")]
408 | Tolerance {
409 | #[schemars(description = "Error margins (absolute or percentage with % suffix)")]
410 | values: Vec<f32>,
411 | },
412 |
413 | #[schemars(description = "Clear the framebuffer to default values")]
414 | Clear,
415 |
416 | #[schemars(description = "Enable/disable depth testing")]
417 | DepthTestEnable {
418 | #[schemars(description = "True to enable depth testing")]
419 | enable: bool,
420 | },
421 |
422 | #[schemars(description = "Enable/disable writing to depth buffer")]
423 | DepthWriteEnable {
424 | #[schemars(description = "True to enable depth writes")]
425 | enable: bool,
426 | },
427 |
428 | #[schemars(description = "Set depth comparison function")]
429 | DepthCompareOp {
430 | #[schemars(description = "Function name (VK_COMPARE_OP_LESS, etc.)")]
431 | op: String,
432 | },
433 |
434 | #[schemars(description = "Enable/disable stencil testing")]
435 | StencilTestEnable {
436 | #[schemars(description = "True to enable stencil testing")]
437 | enable: bool,
438 | },
439 |
440 | #[schemars(description = "Define which winding order is front-facing")]
441 | FrontFace {
442 | #[schemars(description = "Mode (VK_FRONT_FACE_CLOCKWISE, etc.)")]
443 | mode: String,
444 | },
445 |
446 | #[schemars(description = "Configure stencil operation for a face")]
447 | StencilOp {
448 | #[schemars(description = "Face (front, back)")]
449 | face: String,
450 |
451 | #[schemars(description = "Operation name (passOp, failOp, etc.)")]
452 | op_name: String,
453 |
454 | #[schemars(description = "Value (VK_STENCIL_OP_REPLACE, etc.)")]
455 | value: String,
456 | },
457 |
458 | #[schemars(description = "Set reference value for stencil comparisons")]
459 | StencilReference {
460 | #[schemars(description = "Face (front, back)")]
461 | face: String,
462 |
463 | #[schemars(description = "Reference value")]
464 | value: u32,
465 | },
466 |
467 | #[schemars(description = "Set stencil comparison function")]
468 | StencilCompareOp {
469 | #[schemars(description = "Face (front, back)")]
470 | face: String,
471 |
472 | #[schemars(description = "Function (VK_COMPARE_OP_EQUAL, etc.)")]
473 | op: String,
474 | },
475 |
476 | #[schemars(description = "Control which color channels can be written")]
477 | ColorWriteMask {
478 | #[schemars(
479 | description = "Bit flags (VK_COLOR_COMPONENT_R_BIT | VK_COLOR_COMPONENT_G_BIT, etc.)"
480 | )]
481 | mask: String,
482 | },
483 |
484 | #[schemars(description = "Enable/disable logical operations on colors")]
485 | LogicOpEnable {
486 | #[schemars(description = "True to enable logic operations")]
487 | enable: bool,
488 | },
489 |
490 | #[schemars(description = "Set type of logical operation on colors")]
491 | LogicOp {
492 | #[schemars(description = "Operation (VK_LOGIC_OP_XOR, etc.)")]
493 | op: String,
494 | },
495 |
496 | #[schemars(description = "Set face culling mode")]
497 | CullMode {
498 | #[schemars(description = "Mode (VK_CULL_MODE_BACK_BIT, etc.)")]
499 | mode: String,
500 | },
501 |
502 | #[schemars(description = "Set width for line primitives")]
503 | LineWidth {
504 | #[schemars(description = "Width in pixels")]
505 | width: f32,
506 | },
507 |
508 | #[schemars(description = "Specify a feature required by the test")]
509 | Require {
510 | #[schemars(description = "Feature name (subgroup_size, depthstencil, etc.)")]
511 | feature: String,
512 |
513 | #[schemars(description = "Feature parameters")]
514 | parameters: Vec<String>,
515 | },
516 | }
517 |
518 | #[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
519 | pub struct CompileRequest {
520 | #[schemars(description = "The shader stage to compile (vert, frag, comp, geom, tesc, tese)")]
521 | pub stage: ShaderStage,
522 | #[schemars(description = "GLSL shader source code to compile")]
523 | pub source: String,
524 | #[schemars(description = "Path where compiled SPIR-V assembly (.spvasm) will be saved")]
525 | pub tmp_output_path: String,
526 | }
// A batch of shader compilations, each producing a SPIR-V assembly file.
// NOTE(review): not referenced in the visible code; presumably the input of a compile-only
// tool. Confirm.
// Plain `//` comments are used on purpose: `///` doc comments would alter the schemars schema.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct CompileShadersRequest {
    #[schemars(description = "List of shader compile requests (produces SPIR-V assemblies)")]
    pub requests: Vec<CompileRequest>,
}
532 |
// Input of the `compile_run_shaders` tool. The handler first compiles every entry of `requests`.
// It then writes a vkrunner test file (`/tmp/vkrunner_test.shader_test`) from the remaining
// fields in this order: `[require]`, the pass sections, `[vertex data]`, `[test]`.
// Plain `//` comments are used on purpose: `///` doc comments would alter the schemars schema.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct CompileRunShadersRequest {
    // Compiled in order before anything else is written; relative paths resolve under `/tmp`.
    #[schemars(
        description = "List of shader compile requests - each produces a SPIR-V assembly file"
    )]
    pub requests: Vec<CompileRequest>,
    // Emitted as the `[require]` section; omitted entirely when `None` or empty.
    #[schemars(description = "Optional hardware/feature requirements needed for shader execution")]
    pub requirements: Option<Vec<ShaderRunnerRequire>>,
    // Each pass becomes one shader section; SPIR-V passes inline the referenced file's text.
    #[schemars(
        description = "Shader pipeline configuration (references compiled SPIR-V files by path)"
    )]
    pub passes: Vec<ShaderRunnerPass>,
    // Emitted as the `[vertex data]` section when present.
    #[schemars(description = "Optional vertex data for rendering geometry")]
    pub vertex_data: Option<Vec<ShaderRunnerVertexData>>,
    // Emitted under `[test]`, one command line per entry, in order.
    #[schemars(description = "Test commands to execute (drawing, compute, verification, etc.)")]
    pub tests: Vec<ShaderRunnerTest>,
    // NOTE(review): how this path is consumed is not visible in this chunk; presumably the
    // rendered image is saved here. Confirm.
    #[schemars(description = "Optional path to save output image (PNG format)")]
    pub output_path: Option<String>,
}
// MCP server type whose `#[tool(tool_box)]` impl exposes the shaderc/vkrunner tools.
// It has no fields, so cloning it copies no data. `Debug` is derived so the public type can be
// logged and inspected like any other value.
#[derive(Clone, Debug)]
pub struct ShadercVkrunnerMcp {}
554 | #[tool(tool_box)]
555 | impl ShadercVkrunnerMcp {
556 | #[must_use]
557 | pub fn new() -> Self {
558 | Self {}
559 | }
560 |
561 | #[tool(
562 | description = "REQUIRES: (1) Shader compile requests that produce SPIR-V assembly files, (2) Passes referencing these files by path, and (3) Test commands for drawing/computation. Workflow: First compile GLSL to SPIR-V, then reference compiled files in passes, then execute drawing commands in tests to render/compute. Tests MUST include draw/compute commands to produce visible output. Optional: requirements for hardware features, vertex data for geometry, output path for saving image. Every shader MUST have its compiled output referenced in passes."
563 | )]
564 | fn compile_run_shaders(
565 | &self,
566 | #[tool(aggr)] request: CompileRunShadersRequest,
567 | ) -> Result<CallToolResult, McpError> {
568 | use std::fs::File;
569 | use std::io::{Read, Write};
570 | use std::path::Path;
571 | use std::process::{Command, Stdio}; // Still needed for vkrunner
572 |
573 | fn io_err(e: std::io::Error) -> McpError {
574 | McpError::internal_error("IO operation failed", Some(json!({"error": e.to_string()})))
575 | }
576 |
577 | for req in &request.requests {
578 | let shader_kind = match req.stage {
579 | ShaderStage::Vert => ShaderKind::Vertex,
580 | ShaderStage::Frag => ShaderKind::Fragment,
581 | ShaderStage::Tesc => ShaderKind::TessControl,
582 | ShaderStage::Tese => ShaderKind::TessEvaluation,
583 | ShaderStage::Geom => ShaderKind::Geometry,
584 | ShaderStage::Comp => ShaderKind::Compute,
585 | };
586 |
587 | let stage_flag = match req.stage {
588 | ShaderStage::Vert => "vert",
589 | ShaderStage::Frag => "frag",
590 | ShaderStage::Tesc => "tesc",
591 | ShaderStage::Tese => "tese",
592 | ShaderStage::Geom => "geom",
593 | ShaderStage::Comp => "comp",
594 | };
595 |
596 | let tmp_output_path = if req.tmp_output_path.starts_with("/tmp") {
597 | req.tmp_output_path.clone()
598 | } else {
599 | format!("/tmp/{}", req.tmp_output_path)
600 | };
601 |
602 | if let Some(parent) = Path::new(&tmp_output_path).parent() {
603 | std::fs::create_dir_all(parent).map_err(|e| {
604 | McpError::internal_error(
605 | "Failed to create temporary output directory",
606 | Some(json!({"error": e.to_string()})),
607 | )
608 | })?;
609 | }
610 |
611 | // Create compiler and options
612 | let mut compiler = Compiler::new().ok().ok_or_else(|| {
613 | McpError::internal_error(
614 | "Failed to create shaderc compiler",
615 | Some(json!({"error": "Could not instantiate shaderc compiler"})),
616 | )
617 | })?;
618 |
619 | let mut options = CompileOptions::new().ok().ok_or_else(|| {
620 | McpError::internal_error(
621 | "Failed to create shaderc compile options",
622 | Some(json!({"error": "Could not create compiler options"})),
623 | )
624 | })?;
625 |
626 | // Set options equivalent to the CLI flags
627 | options.set_target_env(
628 | shaderc::TargetEnv::Vulkan,
629 | shaderc::EnvVersion::Vulkan1_4 as u32,
630 | );
631 | options.set_optimization_level(OptimizationLevel::Performance);
632 | options.set_generate_debug_info();
633 |
634 | // Compile to SPIR-V assembly
635 | let artifact = match compiler.compile_into_spirv_assembly(
636 | &req.source,
637 | shader_kind,
638 | "shader.glsl", // source name for error reporting
639 | "main", // entry point
640 | Some(&options),
641 | ) {
642 | Ok(artifact) => artifact,
643 | Err(e) => {
644 | let stage_name = match req.stage {
645 | ShaderStage::Vert => "Vertex",
646 | ShaderStage::Frag => "Fragment",
647 | ShaderStage::Tesc => "Tessellation Control",
648 | ShaderStage::Tese => "Tessellation Evaluation",
649 | ShaderStage::Geom => "Geometry",
650 | ShaderStage::Comp => "Compute",
651 | };
652 |
653 | let error_details = format!("{}", e);
654 |
655 | return Ok(CallToolResult::success(vec![Content::text(format!(
656 | "Shader compilation failed for {} shader:\n\nError:\n{}\n\nShader Source:\n{}\n",
657 | stage_name, error_details, req.source
658 | ))]));
659 | }
660 | };
661 |
662 | // Write the compiled SPIR-V assembly to the output file, filtering unsupported lines
663 | let spv_text = artifact.as_text();
664 | let filtered_spv = spv_text
665 | .lines()
666 | .filter(|l| !l.trim_start().starts_with("OpModuleProcessed"))
667 | .collect::<Vec<_>>()
668 | .join("\n");
669 | std::fs::write(&tmp_output_path, filtered_spv).map_err(|e| {
670 | McpError::internal_error(
671 | "Failed to write compiled shader to file",
672 | Some(json!({"error": e.to_string()})),
673 | )
674 | })?;
675 | }
676 |
677 | let shader_test_path = "/tmp/vkrunner_test.shader_test";
678 | let mut shader_test_file = File::create(shader_test_path).map_err(io_err)?;
679 |
680 | if let Some(requirements) = &request.requirements {
681 | if !requirements.is_empty() {
682 | writeln!(shader_test_file, "[require]").map_err(io_err)?;
683 |
684 | for req in requirements {
685 | match req {
686 | ShaderRunnerRequire::CooperativeMatrix {
687 | m,
688 | n,
689 | component_type,
690 | } => {
691 | writeln!(
692 | shader_test_file,
693 | "cooperative_matrix m={m} n={n} c={component_type}"
694 | )
695 | .map_err(io_err)?;
696 | }
697 | ShaderRunnerRequire::DepthStencil(format) => {
698 | writeln!(shader_test_file, "depthstencil {format}").map_err(io_err)?;
699 | }
700 | ShaderRunnerRequire::Framebuffer(format) => {
701 | writeln!(shader_test_file, "framebuffer {format}").map_err(io_err)?;
702 | }
703 | ShaderRunnerRequire::ShaderFloat64 => {
704 | writeln!(shader_test_file, "shaderFloat64").map_err(io_err)?;
705 | }
706 | ShaderRunnerRequire::GeometryShader => {
707 | writeln!(shader_test_file, "geometryShader").map_err(io_err)?;
708 | }
709 | ShaderRunnerRequire::WideLines => {
710 | writeln!(shader_test_file, "wideLines").map_err(io_err)?;
711 | }
712 | ShaderRunnerRequire::LogicOp => {
713 | writeln!(shader_test_file, "logicOp").map_err(io_err)?;
714 | }
715 | ShaderRunnerRequire::SubgroupSize(size) => {
716 | writeln!(shader_test_file, "subgroup_size {size}").map_err(io_err)?;
717 | }
718 | ShaderRunnerRequire::FragmentStoresAndAtomics => {
719 | writeln!(shader_test_file, "fragmentStoresAndAtomics")
720 | .map_err(io_err)?;
721 | }
722 | ShaderRunnerRequire::BufferDeviceAddress => {
723 | writeln!(shader_test_file, "bufferDeviceAddress").map_err(io_err)?;
724 | }
725 | }
726 | }
727 |
728 | writeln!(shader_test_file).map_err(io_err)?;
729 | }
730 | }
731 |
732 | for pass in &request.passes {
733 | match pass {
734 | ShaderRunnerPass::VertPassthrough => {
735 | writeln!(shader_test_file, "[vertex shader passthrough]").map_err(io_err)?;
736 | }
737 | ShaderRunnerPass::VertSpirv { vert_spvasm_path } => {
738 | writeln!(shader_test_file, "[vertex shader spirv]").map_err(io_err)?;
739 |
740 | let mut spvasm = String::new();
741 | let path = if vert_spvasm_path.starts_with("/tmp") {
742 | vert_spvasm_path.clone()
743 | } else {
744 | format!("/tmp/{vert_spvasm_path}")
745 | };
746 |
747 | File::open(&path)
748 | .map_err(|e| {
749 | McpError::internal_error(
750 | format!("Failed to open vertex shader SPIR-V file at {path}"),
751 | Some(json!({"error": e.to_string()})),
752 | )
753 | })?
754 | .read_to_string(&mut spvasm)
755 | .map_err(|e| {
756 | McpError::internal_error(
757 | "Failed to read vertex shader SPIR-V file",
758 | Some(json!({"error": e.to_string()})),
759 | )
760 | })?;
761 | writeln!(shader_test_file, "{spvasm}").map_err(io_err)?;
762 | }
763 | ShaderRunnerPass::FragSpirv { frag_spvasm_path } => {
764 | writeln!(shader_test_file, "[fragment shader spirv]").map_err(io_err)?;
765 |
766 | let mut spvasm = String::new();
767 | let path = if frag_spvasm_path.starts_with("/tmp") {
768 | frag_spvasm_path.clone()
769 | } else {
770 | format!("/tmp/{frag_spvasm_path}")
771 | };
772 |
773 | File::open(&path)
774 | .map_err(|e| {
775 | McpError::internal_error(
776 | format!("Failed to open fragment shader SPIR-V file at {path}"),
777 | Some(json!({"error": e.to_string()})),
778 | )
779 | })?
780 | .read_to_string(&mut spvasm)
781 | .map_err(|e| {
782 | McpError::internal_error(
783 | "Failed to read fragment shader SPIR-V file",
784 | Some(json!({"error": e.to_string()})),
785 | )
786 | })?;
787 | writeln!(shader_test_file, "{spvasm}").map_err(io_err)?;
788 | }
789 | ShaderRunnerPass::CompSpirv { comp_spvasm_path } => {
790 | writeln!(shader_test_file, "[compute shader spirv]").map_err(io_err)?;
791 |
792 | let mut spvasm = String::new();
793 | let path = if comp_spvasm_path.starts_with("/tmp") {
794 | comp_spvasm_path.clone()
795 | } else {
796 | format!("/tmp/{comp_spvasm_path}")
797 | };
798 |
799 | File::open(&path)
800 | .map_err(|e| {
801 | McpError::internal_error(
802 | format!("Failed to open compute shader SPIR-V file at {path}"),
803 | Some(json!({"error": e.to_string()})),
804 | )
805 | })?
806 | .read_to_string(&mut spvasm)
807 | .map_err(|e| {
808 | McpError::internal_error(
809 | "Failed to read compute shader SPIR-V file",
810 | Some(json!({"error": e.to_string()})),
811 | )
812 | })?;
813 | writeln!(shader_test_file, "{spvasm}").map_err(io_err)?;
814 | }
815 | ShaderRunnerPass::GeomSpirv { geom_spvasm_path } => {
816 | writeln!(shader_test_file, "[geometry shader spirv]").map_err(io_err)?;
817 |
818 | let mut spvasm = String::new();
819 | let path = if geom_spvasm_path.starts_with("/tmp") {
820 | geom_spvasm_path.clone()
821 | } else {
822 | format!("/tmp/{geom_spvasm_path}")
823 | };
824 |
825 | File::open(&path)
826 | .map_err(|e| {
827 | McpError::internal_error(
828 | format!("Failed to open geometry shader SPIR-V file at {path}"),
829 | Some(json!({"error": e.to_string()})),
830 | )
831 | })?
832 | .read_to_string(&mut spvasm)
833 | .map_err(|e| {
834 | McpError::internal_error(
835 | "Failed to read geometry shader SPIR-V file",
836 | Some(json!({"error": e.to_string()})),
837 | )
838 | })?;
839 | writeln!(shader_test_file, "{spvasm}").map_err(io_err)?;
840 | }
841 | ShaderRunnerPass::TescSpirv { tesc_spvasm_path } => {
842 | writeln!(shader_test_file, "[tessellation control shader spirv]")
843 | .map_err(io_err)?;
844 |
845 | let mut spvasm = String::new();
846 | let path = if tesc_spvasm_path.starts_with("/tmp") {
847 | tesc_spvasm_path.clone()
848 | } else {
849 | format!("/tmp/{tesc_spvasm_path}")
850 | };
851 |
852 | File::open(&path)
853 | .map_err(|e| {
854 | McpError::internal_error(
855 | format!(
856 | "Failed to open tessellation control shader SPIR-V file at {path}"
857 | ),
858 | Some(json!({"error": e.to_string()})),
859 | )
860 | })?
861 | .read_to_string(&mut spvasm)
862 | .map_err(|e| {
863 | McpError::internal_error(
864 | "Failed to read tessellation control shader SPIR-V file",
865 | Some(json!({"error": e.to_string()})),
866 | )
867 | })?;
868 | writeln!(shader_test_file, "{spvasm}").map_err(io_err)?;
869 | }
870 | ShaderRunnerPass::TeseSpirv { tese_spvasm_path } => {
871 | writeln!(shader_test_file, "[tessellation evaluation shader spirv]")
872 | .map_err(io_err)?;
873 |
874 | let mut spvasm = String::new();
875 | let path = if tese_spvasm_path.starts_with("/tmp") {
876 | tese_spvasm_path.clone()
877 | } else {
878 | format!("/tmp/{tese_spvasm_path}")
879 | };
880 |
881 | File::open(&path)
882 | .map_err(|e| {
883 | McpError::internal_error(
884 | format!(
885 | "Failed to open tessellation evaluation shader SPIR-V file at {path}"
886 | ),
887 | Some(json!({"error": e.to_string()})),
888 | )
889 | })?
890 | .read_to_string(&mut spvasm)
891 | .map_err(|e| {
892 | McpError::internal_error(
893 | "Failed to read tessellation evaluation shader SPIR-V file",
894 | Some(json!({"error": e.to_string()})),
895 | )
896 | })?;
897 | writeln!(shader_test_file, "{spvasm}").map_err(io_err)?;
898 | }
899 | }
900 |
901 | writeln!(shader_test_file).map_err(io_err)?;
902 | }
903 |
904 | if let Some(vertex_data) = &request.vertex_data {
905 | writeln!(shader_test_file, "[vertex data]").map_err(io_err)?;
906 |
907 | for data in vertex_data {
908 | if let ShaderRunnerVertexData::AttributeFormat { location, format } = data {
909 | writeln!(shader_test_file, "{location}/{format}").map_err(io_err)?;
910 | } else {
911 | match data {
912 | ShaderRunnerVertexData::Vec2 { x, y } => {
913 | write!(shader_test_file, "{x} {y}").map_err(io_err)?;
914 | }
915 | ShaderRunnerVertexData::Vec3 { x, y, z } => {
916 | write!(shader_test_file, "{x} {y} {z}").map_err(io_err)?;
917 | }
918 | ShaderRunnerVertexData::Vec4 { x, y, z, w } => {
919 | write!(shader_test_file, "{x} {y} {z} {w}").map_err(io_err)?;
920 | }
921 | ShaderRunnerVertexData::RGB { r, g, b } => {
922 | write!(shader_test_file, "{r} {g} {b}").map_err(io_err)?;
923 | }
924 | ShaderRunnerVertexData::Hex { value } => {
925 | write!(shader_test_file, "{value}").map_err(io_err)?;
926 | }
927 | ShaderRunnerVertexData::GenericComponents { components } => {
928 | for component in components {
929 | write!(shader_test_file, "{component} ").map_err(io_err)?;
930 | }
931 | }
932 | _ => {}
933 | }
934 | writeln!(shader_test_file).map_err(io_err)?;
935 | }
936 | }
937 |
938 | writeln!(shader_test_file).map_err(io_err)?;
939 | }
940 |
941 | writeln!(shader_test_file, "[test]").map_err(io_err)?;
942 |
943 | for test_cmd in &request.tests {
944 | match test_cmd {
945 | ShaderRunnerTest::FragmentEntrypoint { name } => {
946 | writeln!(shader_test_file, "fragment entrypoint {name}").map_err(io_err)?;
947 | }
948 | ShaderRunnerTest::VertexEntrypoint { name } => {
949 | writeln!(shader_test_file, "vertex entrypoint {name}").map_err(io_err)?;
950 | }
951 | ShaderRunnerTest::ComputeEntrypoint { name } => {
952 | writeln!(shader_test_file, "compute entrypoint {name}").map_err(io_err)?;
953 | }
954 | ShaderRunnerTest::GeometryEntrypoint { name } => {
955 | writeln!(shader_test_file, "geometry entrypoint {name}").map_err(io_err)?;
956 | }
957 | ShaderRunnerTest::DrawRect {
958 | x,
959 | y,
960 | width,
961 | height,
962 | } => {
963 | writeln!(shader_test_file, "draw rect {x} {y} {width} {height}")
964 | .map_err(io_err)?;
965 | }
966 | ShaderRunnerTest::DrawArrays {
967 | primitive_type,
968 | first,
969 | count,
970 | } => {
971 | writeln!(
972 | shader_test_file,
973 | "draw arrays {primitive_type} {first} {count}"
974 | )
975 | .map_err(io_err)?;
976 | }
977 | ShaderRunnerTest::DrawArraysIndexed {
978 | primitive_type,
979 | first,
980 | count,
981 | } => {
982 | writeln!(
983 | shader_test_file,
984 | "draw arrays indexed {primitive_type} {first} {count}"
985 | )
986 | .map_err(io_err)?;
987 | }
988 | ShaderRunnerTest::SSBO {
989 | binding,
990 | size,
991 | data,
992 | descriptor_set,
993 | } => {
994 | let set_prefix = if let Some(set) = descriptor_set {
995 | format!("{set}:")
996 | } else {
997 | String::new()
998 | };
999 |
1000 | if let Some(size) = size {
1001 | writeln!(shader_test_file, "ssbo {set_prefix}{binding} {size}")
1002 | .map_err(io_err)?;
1003 | } else if let Some(_data) = data {
1004 | writeln!(shader_test_file, "ssbo {set_prefix}{binding} data")
1005 | .map_err(io_err)?;
1006 | }
1007 | }
1008 | ShaderRunnerTest::SSBOSubData {
1009 | binding,
1010 | data_type,
1011 | offset,
1012 | values,
1013 | descriptor_set,
1014 | } => {
1015 | let set_prefix = if let Some(set) = descriptor_set {
1016 | format!("{set}:")
1017 | } else {
1018 | String::new()
1019 | };
1020 |
1021 | write!(
1022 | shader_test_file,
1023 | "ssbo {set_prefix}{binding} subdata {data_type} {offset}"
1024 | )
1025 | .map_err(io_err)?;
1026 | for value in values {
1027 | write!(shader_test_file, " {value}").map_err(io_err)?;
1028 | }
1029 | writeln!(shader_test_file).map_err(io_err)?;
1030 | }
1031 | ShaderRunnerTest::UBO {
1032 | binding,
1033 | data: _,
1034 | descriptor_set,
1035 | } => {
1036 | let set_prefix = if let Some(set) = descriptor_set {
1037 | format!("{set}:")
1038 | } else {
1039 | String::new()
1040 | };
1041 |
1042 | writeln!(shader_test_file, "ubo {set_prefix}{binding} data").map_err(io_err)?;
1043 | }
1044 | ShaderRunnerTest::UBOSubData {
1045 | binding,
1046 | data_type,
1047 | offset,
1048 | values,
1049 | descriptor_set,
1050 | } => {
1051 | let set_prefix = if let Some(set) = descriptor_set {
1052 | format!("{set}:")
1053 | } else {
1054 | String::new()
1055 | };
1056 |
1057 | write!(
1058 | shader_test_file,
1059 | "ubo {set_prefix}{binding} subdata {data_type} {offset}"
1060 | )
1061 | .map_err(io_err)?;
1062 | for value in values {
1063 | write!(shader_test_file, " {value}").map_err(io_err)?;
1064 | }
1065 | writeln!(shader_test_file).map_err(io_err)?;
1066 | }
1067 | ShaderRunnerTest::BufferLayout {
1068 | buffer_type,
1069 | layout_type,
1070 | } => {
1071 | writeln!(shader_test_file, "{buffer_type} layout {layout_type}")
1072 | .map_err(io_err)?;
1073 | }
1074 | ShaderRunnerTest::Push {
1075 | data_type,
1076 | offset,
1077 | values,
1078 | } => {
1079 | write!(shader_test_file, "push {data_type} {offset}").map_err(io_err)?;
1080 | for value in values {
1081 | write!(shader_test_file, " {value}").map_err(io_err)?;
1082 | }
1083 | writeln!(shader_test_file).map_err(io_err)?;
1084 | }
1085 | ShaderRunnerTest::PushLayout { layout_type } => {
1086 | writeln!(shader_test_file, "push layout {layout_type}").map_err(io_err)?;
1087 | }
1088 | ShaderRunnerTest::Compute { x, y, z } => {
1089 | writeln!(shader_test_file, "compute {x} {y} {z}").map_err(io_err)?;
1090 | }
1091 | ShaderRunnerTest::Probe {
1092 | probe_type,
1093 | format,
1094 | args,
1095 | } => {
1096 | write!(shader_test_file, "probe {probe_type} {format}").map_err(io_err)?;
1097 | for arg in args {
1098 | write!(shader_test_file, " {arg}").map_err(io_err)?;
1099 | }
1100 | writeln!(shader_test_file).map_err(io_err)?;
1101 | }
1102 | ShaderRunnerTest::RelativeProbe {
1103 | probe_type,
1104 | format,
1105 | args,
1106 | } => {
1107 | write!(shader_test_file, "relative probe {probe_type} {format}")
1108 | .map_err(io_err)?;
1109 | for arg in args {
1110 | write!(shader_test_file, " {arg}").map_err(io_err)?;
1111 | }
1112 | writeln!(shader_test_file).map_err(io_err)?;
1113 | }
1114 | ShaderRunnerTest::Tolerance { values } => {
1115 | write!(shader_test_file, "tolerance").map_err(io_err)?;
1116 | for value in values {
1117 | write!(shader_test_file, " {value}").map_err(io_err)?;
1118 | }
1119 | writeln!(shader_test_file).map_err(io_err)?;
1120 | }
1121 | ShaderRunnerTest::Clear => {
1122 | writeln!(shader_test_file, "clear").map_err(io_err)?;
1123 | }
1124 | ShaderRunnerTest::DepthTestEnable { enable } => {
1125 | writeln!(shader_test_file, "depthTestEnable {enable}").map_err(io_err)?;
1126 | }
1127 | ShaderRunnerTest::DepthWriteEnable { enable } => {
1128 | writeln!(shader_test_file, "depthWriteEnable {enable}").map_err(io_err)?;
1129 | }
1130 | ShaderRunnerTest::DepthCompareOp { op } => {
1131 | writeln!(shader_test_file, "depthCompareOp {op}").map_err(io_err)?;
1132 | }
1133 | ShaderRunnerTest::StencilTestEnable { enable } => {
1134 | writeln!(shader_test_file, "stencilTestEnable {enable}").map_err(io_err)?;
1135 | }
1136 | ShaderRunnerTest::FrontFace { mode } => {
1137 | writeln!(shader_test_file, "frontFace {mode}").map_err(io_err)?;
1138 | }
1139 | ShaderRunnerTest::StencilOp {
1140 | face,
1141 | op_name,
1142 | value,
1143 | } => {
1144 | writeln!(shader_test_file, "{face}.{op_name} {value}",).map_err(io_err)?;
1145 | }
1146 | ShaderRunnerTest::StencilReference { face, value } => {
1147 | writeln!(shader_test_file, "{face}.reference {value}",).map_err(io_err)?;
1148 | }
1149 | ShaderRunnerTest::StencilCompareOp { face, op } => {
1150 | writeln!(shader_test_file, "{face}.compareOp {op}",).map_err(io_err)?;
1151 | }
1152 | ShaderRunnerTest::ColorWriteMask { mask } => {
1153 | writeln!(shader_test_file, "colorWriteMask {mask}",).map_err(io_err)?;
1154 | }
1155 | ShaderRunnerTest::LogicOpEnable { enable } => {
1156 | writeln!(shader_test_file, "logicOpEnable {enable}",).map_err(io_err)?;
1157 | }
1158 | ShaderRunnerTest::LogicOp { op } => {
1159 | writeln!(shader_test_file, "logicOp {op}",).map_err(io_err)?;
1160 | }
1161 | ShaderRunnerTest::CullMode { mode } => {
1162 | writeln!(shader_test_file, "cullMode {mode}",).map_err(io_err)?;
1163 | }
1164 | ShaderRunnerTest::LineWidth { width } => {
1165 | writeln!(shader_test_file, "lineWidth {width}").map_err(io_err)?;
1166 | }
1167 | ShaderRunnerTest::Require {
1168 | feature,
1169 | parameters,
1170 | } => {
1171 | write!(shader_test_file, "require {feature}").map_err(io_err)?;
1172 | for param in parameters {
1173 | write!(shader_test_file, " {param}").map_err(io_err)?;
1174 | }
1175 | writeln!(shader_test_file).map_err(io_err)?;
1176 | }
1177 | }
1178 | }
1179 |
1180 | shader_test_file.flush().map_err(io_err)?;
1181 |
1182 | let tmp_image_path = "/tmp/vkrunner_output.ppm";
1183 | if let Some(output_path) = &request.output_path {
1184 | if output_path.starts_with("/tmp") {
1185 | tmp_image_path.to_string()
1186 | } else {
1187 | format!("/tmp/{output_path}")
1188 | }
1189 | } else {
1190 | tmp_image_path.to_string()
1191 | };
1192 | let mut vkrunner_args = vec![shader_test_path];
1193 |
1194 | if request.output_path.is_some() {
1195 | vkrunner_args.push("--image");
1196 | vkrunner_args.push(tmp_image_path);
1197 | }
1198 |
1199 | let vkrunner_output = Command::new("vkrunner")
1200 | .args(&vkrunner_args)
1201 | .stdout(Stdio::piped())
1202 | .stderr(Stdio::piped())
1203 | .output()
1204 | .map_err(|e| {
1205 | McpError::internal_error(
1206 | "Failed to run vkrunner",
1207 | Some(json!({"error": e.to_string()})),
1208 | )
1209 | })?;
1210 |
1211 | let stdout = String::from_utf8_lossy(&vkrunner_output.stdout).to_string();
1212 | let stderr = String::from_utf8_lossy(&vkrunner_output.stderr).to_string();
1213 |
1214 | let mut result_message = if vkrunner_output.status.success() {
1215 | format!(
1216 | "Shader compilation successful using shaderc-rs.\nVkRunner execution successful.\n\nOutput:\n{stdout}\n\n"
1217 | )
1218 | } else {
1219 | format!(
1220 | "Shader compilation successful using shaderc-rs.\nVkRunner execution failed.\n\nOutput:\n{stdout}\n\nError:\n{stderr}\n\n",
1221 | )
1222 | };
1223 |
1224 | if let Some(output_path) = &request.output_path {
1225 | if vkrunner_output.status.success() && Path::new(tmp_image_path).exists() {
1226 | match read_and_decode_ppm_file(tmp_image_path) {
1227 | Ok(img) => {
1228 | if let Some(parent) = Path::new(output_path).parent() {
1229 | if !parent.as_os_str().is_empty() {
1230 | std::fs::create_dir_all(parent).map_err(|e| {
1231 | McpError::internal_error(
1232 | "Failed to create output directory",
1233 | Some(json!({"error": e.to_string()})),
1234 | )
1235 | })?;
1236 | }
1237 | }
1238 |
1239 | img.save(output_path).map_err(|e| {
1240 | McpError::internal_error(
1241 | "Failed to save output image",
1242 | Some(json!({"error": e.to_string()})),
1243 | )
1244 | })?;
1245 |
1246 | result_message.push_str(&format!("Image saved to: {output_path}\n"));
1247 | }
1248 | Err(e) => {
1249 | result_message.push_str(&format!("Failed to convert output image: {e}\n"));
1250 | }
1251 | }
1252 | } else if vkrunner_output.status.success() {
1253 | result_message.push_str("No output image was generated by VkRunner.\n");
1254 | }
1255 | }
1256 |
1257 | result_message.push_str("\nShader Test File Contents:\n");
1258 | result_message.push_str(
1259 | &std::fs::read_to_string(shader_test_path)
1260 | .unwrap_or_else(|_| "Failed to read shader test file".to_string()),
1261 | );
1262 |
1263 | Ok(CallToolResult::success(vec![Content::text(result_message)]))
1264 | }
1265 | }
1266 |
1267 | impl Default for ShadercVkrunnerMcp {
1268 | fn default() -> Self {
1269 | Self::new()
1270 | }
1271 | }
1272 |
// Declares `Echo` through the `const_string!` macro (presumably rmcp's constant-string
// marker type whose value is "echo" — confirm against the rmcp docs).
// NOTE(review): `Echo` is not referenced anywhere in this part of the file; it may be a
// leftover from an rmcp example. Confirm it is unused elsewhere before removing it.
1273 | const_string!(Echo = "echo");
1274 | #[tool(tool_box)]
1275 | impl ServerHandler for ShadercVkrunnerMcp {
1276 | fn get_info(&self) -> ServerInfo {
1277 | ServerInfo {
1278 | protocol_version: ProtocolVersion::V_2024_11_05,
1279 | capabilities: ServerCapabilities::builder()
1280 | .enable_tools()
1281 | .build(),
1282 | server_info: Implementation::from_build_env(),
1283 | instructions: Some("This server provides tools for compiling and running GLSL shaders using Vulkan infrastructure. The typical workflow is:
1284 |
1285 | 1. Compile GLSL source code to SPIR-V assembly using the 'compile_run_shaders' tool
1286 | 2. Specify hardware/feature requirements if needed (e.g., geometry shaders, floating-point formats)
1287 | 3. Reference compiled SPIR-V files in shader passes
1288 | 4. Define vertex data if rendering geometry
1289 | 5. Set up test commands to draw or compute
1290 | 6. Optionally save the rendered output as an image".to_string()),
1291 | }
1292 | }
1293 | }
1294 |
1295 | #[derive(Parser, Debug)]
1296 | #[command(about)]
1297 | struct Args {
1298 | #[clap(short, long, value_parser)]
1299 | work_dir: Option<PathBuf>,
1300 | }
1301 |
1302 | #[tokio::main]
1303 | async fn main() -> Result<()> {
1304 | let args = Args::parse();
1305 |
1306 | if let Some(work_dir) = args.work_dir {
1307 | std::env::set_current_dir(&work_dir)
1308 | .map_err(|e| {
1309 | eprintln!("Failed to set working directory to {work_dir:?}: {e}");
1310 | std::process::exit(1);
1311 | })
1312 | .unwrap();
1313 | }
1314 |
1315 | tracing_subscriber::fmt()
1316 | .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::DEBUG.into()))
1317 | .with_writer(std::io::stderr)
1318 | .with_ansi(false)
1319 | .init();
1320 |
1321 | tracing::info!("Starting MCP server");
1322 |
1323 | let service = ShadercVkrunnerMcp::new()
1324 | .serve(stdio())
1325 | .await
1326 | .inspect_err(|e| {
1327 | tracing::error!("serving error: {:?}", e);
1328 | })?;
1329 |
1330 | service.waiting().await?;
1331 | Ok(())
1332 | }
1333 |
```