From 7e93be389a7b18f701a2eb19a3fb0b55751c50bb Mon Sep 17 00:00:00 2001 From: HavenCTO Date: Sun, 25 Jan 2026 18:18:22 -0500 Subject: [PATCH] initial commit of AHavenVLMConnector --- plugins/AHavenVLMConnector/CHANGELOG.md | 8 + plugins/AHavenVLMConnector/README.md | 143 +++++ .../AHavenVLMConnector/ahavenvlmconnector.yml | 22 + plugins/AHavenVLMConnector/exit_tracker.py | 98 +++ .../AHavenVLMConnector/haven_media_handler.py | 333 ++++++++++ .../AHavenVLMConnector/haven_vlm_config.py | 445 +++++++++++++ .../AHavenVLMConnector/haven_vlm_connector.py | 444 +++++++++++++ .../AHavenVLMConnector/haven_vlm_engine.py | 299 +++++++++ .../AHavenVLMConnector/haven_vlm_utility.py | 316 +++++++++ plugins/AHavenVLMConnector/requirements.txt | 8 + plugins/AHavenVLMConnector/run_tests.py | 110 ++++ .../test_dependency_management.py | 98 +++ .../test_haven_media_handler.py | 387 +++++++++++ .../test_haven_vlm_config.py | 286 +++++++++ .../test_haven_vlm_connector.py | 451 +++++++++++++ .../test_haven_vlm_engine.py | 544 ++++++++++++++++ .../test_haven_vlm_utility.py | 604 ++++++++++++++++++ 17 files changed, 4596 insertions(+) create mode 100644 plugins/AHavenVLMConnector/CHANGELOG.md create mode 100644 plugins/AHavenVLMConnector/README.md create mode 100644 plugins/AHavenVLMConnector/ahavenvlmconnector.yml create mode 100644 plugins/AHavenVLMConnector/exit_tracker.py create mode 100644 plugins/AHavenVLMConnector/haven_media_handler.py create mode 100644 plugins/AHavenVLMConnector/haven_vlm_config.py create mode 100644 plugins/AHavenVLMConnector/haven_vlm_connector.py create mode 100644 plugins/AHavenVLMConnector/haven_vlm_engine.py create mode 100644 plugins/AHavenVLMConnector/haven_vlm_utility.py create mode 100644 plugins/AHavenVLMConnector/requirements.txt create mode 100644 plugins/AHavenVLMConnector/run_tests.py create mode 100644 plugins/AHavenVLMConnector/test_dependency_management.py create mode 100644 plugins/AHavenVLMConnector/test_haven_media_handler.py create mode 100644 plugins/AHavenVLMConnector/test_haven_vlm_config.py create mode 100644 plugins/AHavenVLMConnector/test_haven_vlm_connector.py create mode 100644 plugins/AHavenVLMConnector/test_haven_vlm_engine.py create mode 100644 plugins/AHavenVLMConnector/test_haven_vlm_utility.py diff --git a/plugins/AHavenVLMConnector/CHANGELOG.md b/plugins/AHavenVLMConnector/CHANGELOG.md new file mode 100644 index 00000000..9fbe9a2a --- /dev/null +++ b/plugins/AHavenVLMConnector/CHANGELOG.md @@ -0,0 +1,8 @@ +# Changelog + +All notable changes to the A Haven VLM Connector project will be documented in this file. + +## [1.0.0] - 2025-06-29 + +### Added +- **Initial release** diff --git a/plugins/AHavenVLMConnector/README.md b/plugins/AHavenVLMConnector/README.md new file mode 100644 index 00000000..f5abf8ab --- /dev/null +++ b/plugins/AHavenVLMConnector/README.md @@ -0,0 +1,143 @@ +# A Haven VLM Connector + +A StashApp plugin for Vision-Language Model (VLM) based content tagging and analysis. This plugin is designed with a **local-first philosophy**, empowering users to run analysis on their own hardware (using CPU or GPU) and their local network. It also supports cloud-based VLM endpoints for additional flexibility. The Haven VLM Engine provides advanced automatic content detection and tagging, delivering superior accuracy compared to traditional image classification methods. 
## Features

- **Local Network Empowerment**: Distribute processing across home/office computers without cloud dependencies
- **Context-Aware Detection**: Leverages Vision-Language Models' understanding of visual relationships
- **Advanced Dependency Management**: Uses PythonDepManager for automatic dependency installation
- **Enjoying Funscript Haven?** Check out more tools and projects at https://github.com/Haven-hvn

## Requirements

- Python 3.8+
- StashApp
- PythonDepManager plugin (automatically handles dependencies)
- OpenAI-compatible VLM endpoints (local or cloud-based)

## Installation

1. Clone or download this plugin to your StashApp plugins directory
2. Ensure PythonDepManager is installed in your StashApp plugins
3. Configure your VLM endpoints in `haven_vlm_config.py` (local network endpoints recommended)
4. Restart StashApp

The plugin automatically manages all dependencies.

## Why Local-First?

- **Complete Control**: Process sensitive content on your own hardware
- **Cost Effective**: Avoid cloud processing fees by using existing resources
- **Flexible Scaling**: Add more computers to your local network for increased capacity
- **Privacy Focused**: Keep your media completely private
- **Hybrid Options**: Combine local and cloud endpoints for optimal flexibility

```mermaid
graph LR
A[User's Computer] --> B[Local GPU Machine]
A --> C[Local CPU Machine 1]
A --> D[Local CPU Machine 2]
A --> E[Cloud Endpoint]
```

## Configuration

### Easy Setup with LM Studio

[LM Studio](https://lmstudio.ai/) provides the easiest way to configure local endpoints:

1. Download and install [LM Studio](https://lmstudio.ai/)
2. [Search for or download](https://huggingface.co/models) a vision-capable model; tested with (in order of decreasing accuracy): zai-org/glm-4.6v-flash, huihui-mistral-small-3.2-24b-instruct-2506-abliterated-v2, qwen/qwen3-vl-8b, lfm2.5-vl
3. Load your desired model
4. On the Developer tab, start the local server using the Start toggle
5. Optionally click the Settings gear, then toggle *Serve on local network*
6. Optionally configure `haven_vlm_config.py`:

By default, localhost is included in the config; **remove the cloud endpoint if you don't want automatic failover**:
```python
{
    "base_url": "http://localhost:1234/v1", # LM Studio default
    "api_key": "", # API key not required
    "name": "lm-studio-local",
    "weight": 5,
    "is_fallback": False
}
```

### Tag Configuration

```python
"tag_list": [
    "Basketball point", "Foul", "Break-away", "Turnover"
]
```

### Processing Settings

```python
VIDEO_FRAME_INTERVAL = 2.0 # Process every 2 seconds
CONCURRENT_TASK_LIMIT = 8 # Adjust based on local hardware
```

## Usage

### Tag Videos
1. Tag scenes with `VLM_TagMe`
2. Run the "Tag Videos" task
3. The plugin processes content using local/network resources
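
### Verifying Endpoints

Before queuing a large batch, it can help to confirm that each configured endpoint is reachable and has a model loaded. The snippet below is a minimal sketch assuming an OpenAI-compatible server (such as LM Studio) that serves `GET /v1/models`; the URL is a placeholder for your own endpoint.

```python
import json
import urllib.request

ENDPOINT = "http://localhost:1234/v1"  # replace with your endpoint's base_url

# OpenAI-compatible servers list their loaded models under /v1/models
with urllib.request.urlopen(f"{ENDPOINT}/models", timeout=10) as resp:
    models = json.load(resp)

for model in models.get("data", []):
    print(model["id"])
```

If this prints no model IDs, load a model in the server before running the "Tag Videos" task.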
### Performance Tips
- Start with 2-3 local machines for load balancing
- Assign higher weights to GPU-enabled machines
- Adjust `CONCURRENT_TASK_LIMIT` based on total system resources
- Use SSD storage for better I/O performance

## File Structure

```
AHavenVLMConnector/
├── ahavenvlmconnector.yml
├── haven_vlm_connector.py
├── haven_vlm_config.py
├── haven_vlm_engine.py
├── haven_media_handler.py
├── haven_vlm_utility.py
├── requirements.txt
└── README.md
```

## Troubleshooting

### Local Network Setup
- Ensure firewalls allow communication between machines
- Verify all local endpoints are running VLM services
- Use static IPs for local machines
- Check that `http://local-machine-ip:port/v1` responds correctly

### Performance Optimization
- **Distribute Load**: Use multiple mid-range machines instead of a single high-end machine
- **GPU Prioritization**: Assign the highest weight to GPU machines
- **Network Speed**: Use wired Ethernet connections for faster transfer
- **Resource Monitoring**: Watch system resources during processing

## Development

### Adding Local Endpoints
1. Install a VLM service on network machines
2. Add endpoint configuration with local IPs
3. Set appropriate weights based on hardware capability

### Custom Models
Use any OpenAI-compatible model that supports:
- POST requests to `/v1/chat/completions`
- Vision capabilities with image input
- Local deployment options

### Log Messages

Check StashApp logs for detailed processing information and error messages.

## License

This project is part of the StashApp Community Scripts collection.
\ No newline at end of file
diff --git a/plugins/AHavenVLMConnector/ahavenvlmconnector.yml b/plugins/AHavenVLMConnector/ahavenvlmconnector.yml
new file mode 100644
index 00000000..aa142a8d
--- /dev/null
+++ b/plugins/AHavenVLMConnector/ahavenvlmconnector.yml
@@ -0,0 +1,22 @@
name: A Haven VLM Connector
# requires: PythonDepManager
description: Tag videos with Vision-Language Models using any OpenAI-compatible VLM endpoint
version: 1.0.0
url: https://github.com/stashapp/CommunityScripts/tree/main/plugins/AHavenVLMConnector
exec:
  - python
  - "{pluginDir}/haven_vlm_connector.py"
interface: raw
tasks:
  - name: Tag Videos
    description: Run VLM analysis on videos with the VLM_TagMe tag
    defaultArgs:
      mode: tag_videos
  - name: Collect Incorrect Markers and Images
    description: Collects data from markers and images that were VLM-tagged but manually marked with VLM_Incorrect because the VLM made a mistake. The data is collected and written to a file that can be used to improve the VLM models.
    defaultArgs:
      mode: collect_incorrect_markers
  - name: Find Marker Settings
    description: Find optimal marker settings based on a video that has manually tuned markers and was previously processed by the VLM. Only one video should have VLM_TagMe before running.
+ defaultArgs: + mode: find_marker_settings \ No newline at end of file diff --git a/plugins/AHavenVLMConnector/exit_tracker.py b/plugins/AHavenVLMConnector/exit_tracker.py new file mode 100644 index 00000000..74a4cea8 --- /dev/null +++ b/plugins/AHavenVLMConnector/exit_tracker.py @@ -0,0 +1,98 @@ +""" +Comprehensive sys.exit tracking module +Instruments all sys.exit() calls with full call stack and context +""" + +import sys +import traceback +from typing import Optional + +# Store original sys.exit +original_exit = sys.exit + +# Track if we've already patched +_exit_tracker_patched = False + +def install_exit_tracker(logger=None) -> None: + """ + Install the exit tracker by monkey-patching sys.exit + + Args: + logger: Optional logger instance (will use fallback print if None) + """ + global _exit_tracker_patched, original_exit + + if _exit_tracker_patched: + return + + # Store original if not already stored + if hasattr(sys, 'exit') and sys.exit is not original_exit: + original_exit = sys.exit + + def tracked_exit(code: int = 0) -> None: + """Track sys.exit() calls with full call stack""" + # Get current stack trace (not from exception, but current call stack) + stack = traceback.extract_stack() + + # Format the stack trace, excluding this tracking function + stack_lines = [] + for frame in stack: + # Skip internal Python frames and this tracker + if ('tracked_exit' not in frame.filename and + '/usr/lib' not in frame.filename and + '/System/Library' not in frame.filename and + 'exit_tracker.py' not in frame.filename): + stack_lines.append( + f" File \"{frame.filename}\", line {frame.lineno}, in {frame.name}\n {frame.line}" + ) + + # Take last 15 frames to see the full call chain + stack_str = '\n'.join(stack_lines[-15:]) + + # Get current exception info if available + exc_info = sys.exc_info() + exc_str = "" + if exc_info[0] is not None: + exc_str = f"\n Active Exception: {exc_info[0].__name__}: {exc_info[1]}" + + # Build the error message + error_msg = f"""[DEBUG_EXIT_CODE] ========================================== +[DEBUG_EXIT_CODE] sys.exit() called with code: {code} +[DEBUG_EXIT_CODE] Call stack (last 15 frames): +{stack_str} +{exc_str} +[DEBUG_EXIT_CODE] ==========================================""" + + # Log using provided logger or fallback to print + if logger: + try: + logger.error(error_msg) + except Exception as log_error: + print(f"[EXIT_TRACKER_LOGGER_ERROR] Failed to log: {log_error}") + print(error_msg) + else: + print(error_msg) + + # Call original exit + original_exit(code) + + # Install the tracker + sys.exit = tracked_exit + _exit_tracker_patched = True + + if logger: + logger.debug("[DEBUG_EXIT_CODE] Exit tracker installed successfully") + else: + print("[DEBUG_EXIT_CODE] Exit tracker installed successfully") + +def uninstall_exit_tracker() -> None: + """Uninstall the exit tracker and restore original sys.exit""" + global _exit_tracker_patched, original_exit + + if _exit_tracker_patched: + sys.exit = original_exit + _exit_tracker_patched = False + +# Auto-install on import (can be disabled by calling uninstall_exit_tracker()) +if not _exit_tracker_patched: + install_exit_tracker() \ No newline at end of file diff --git a/plugins/AHavenVLMConnector/haven_media_handler.py b/plugins/AHavenVLMConnector/haven_media_handler.py new file mode 100644 index 00000000..163562a4 --- /dev/null +++ b/plugins/AHavenVLMConnector/haven_media_handler.py @@ -0,0 +1,333 @@ +""" +Haven Media Handler Module +Handles StashApp media operations and tag management +""" + +import os +import 
zipfile
import shutil
from typing import List, Dict, Any, Optional, Tuple, Set
from datetime import datetime
import json

# Use PythonDepManager for dependency management
try:
    from PythonDepManager import ensure_import
    ensure_import("stashapi:stashapp-tools==0.2.58")

    from stashapi.stashapp import StashInterface, StashVersion
    import stashapi.log as log
except ImportError as e:
    print(f"stashapp-tools not found: {e}")
    print("Please ensure PythonDepManager is available and stashapp-tools is accessible")
    raise

import haven_vlm_config as config

# Global variables
tag_id_cache: Dict[str, int] = {}
vlm_tag_ids_cache: Set[int] = set()
stash_version: Optional[StashVersion] = None
end_seconds_support: bool = False

# Stash interface and tag IDs
stash: Optional[StashInterface] = None
vlm_errored_tag_id: Optional[int] = None
vlm_tagme_tag_id: Optional[int] = None
vlm_base_tag_id: Optional[int] = None
vlm_tagged_tag_id: Optional[int] = None
vr_tag_id: Optional[int] = None
vlm_incorrect_tag_id: Optional[int] = None

def initialize(connection: Dict[str, Any]) -> None:
    """Initialize the media handler with StashApp connection"""
    global stash, vlm_errored_tag_id, vlm_tagme_tag_id, vlm_base_tag_id
    global vlm_tagged_tag_id, vr_tag_id, end_seconds_support, stash_version
    global vlm_incorrect_tag_id

    # Initialize the Stash API
    stash = StashInterface(connection)

    # Initialize "metadata" tags
    vlm_errored_tag_id = stash.find_tag(config.config.vlm_errored_tag_name, create=True)["id"]
    vlm_tagme_tag_id = stash.find_tag(config.config.vlm_tagme_tag_name, create=True)["id"]
    vlm_base_tag_id = stash.find_tag(config.config.vlm_base_tag_name, create=True)["id"]
    vlm_tagged_tag_id = stash.find_tag(config.config.vlm_tagged_tag_name, create=True)["id"]
    vlm_incorrect_tag_id = stash.find_tag(config.config.vlm_incorrect_tag_name, create=True)["id"]

    # Get VR tag from configuration; guard against a configured but missing tag
    vr_tag_name = stash.get_configuration()["ui"].get("vrTag", None)
    if not vr_tag_name:
        log.warning("No VR tag found in configuration")
        vr_tag_id = None
    else:
        vr_tag = stash.find_tag(vr_tag_name)
        if vr_tag:
            vr_tag_id = vr_tag["id"]
        else:
            log.warning(f"Configured VR tag '{vr_tag_name}' does not exist")
            vr_tag_id = None

    # end_seconds on markers is only supported by StashApp builds newer than this
    stash_version = get_stash_version()
    end_second_support_beyond = StashVersion("v0.27.2-76648")
    end_seconds_support = stash_version > end_second_support_beyond

def get_stash_version() -> StashVersion:
    """Get the current StashApp version"""
    if not stash:
        raise RuntimeError("Stash interface not initialized")
    return stash.stash_version()

# ----------------- Tag Management Methods -----------------

def get_tag_ids(tag_names: List[str], create: bool = False) -> List[int]:
    """Get tag IDs for multiple tag names"""
    return [get_tag_id(tag_name, create) for tag_name in tag_names]

def get_tag_id(tag_name: str, create: bool = False) -> Optional[int]:
    """Get tag ID for a single tag name, creating the tag on demand"""
    if tag_name not in tag_id_cache:
        stashtag = stash.find_tag(tag_name)
        if stashtag:
            tag_id_cache[tag_name] = stashtag["id"]
            return stashtag["id"]
        else:
            if not create:
                return None
            tag = stash.create_tag({
                "name": tag_name,
                "ignore_auto_tag": True,
                "parent_ids": [vlm_base_tag_id]
            })['id']
            tag_id_cache[tag_name] = tag
            vlm_tag_ids_cache.add(tag)
            return tag
    return tag_id_cache.get(tag_name)

def get_vlm_tags() -> List[int]:
    """Get all VLM-generated tags"""
    if len(vlm_tag_ids_cache) == 0:
        vlm_tags = [
            item['id'] for item in stash.find_tags(
                f={"parents": {"value": vlm_base_tag_id, "modifier": "INCLUDES"}},
                fragment="id"
            )
        ]
vlm_tag_ids_cache.update(vlm_tags) + else: + vlm_tags = list(vlm_tag_ids_cache) + return vlm_tags + +def is_scene_tagged(tags: List[Dict[str, Any]]) -> bool: + """Check if a scene has been tagged by VLM""" + for tag in tags: + if tag['id'] == vlm_tagged_tag_id: + return True + return False + +def is_vr_scene(tags: List[Dict[str, Any]]) -> bool: + """Check if a scene is VR content""" + for tag in tags: + if tag['id'] == vr_tag_id: + return True + return False + +# ----------------- Scene Management Methods ----------------- + +def add_tags_to_video(video_id: int, tag_ids: List[int], add_tagged: bool = True) -> None: + """Add tags to a video scene""" + if add_tagged: + tag_ids.append(vlm_tagged_tag_id) + stash.update_scenes({ + "ids": [video_id], + "tag_ids": {"ids": tag_ids, "mode": "ADD"} + }) + +def clear_all_tags_from_video(scene: Dict[str, Any]) -> None: + """Clear all tags from a video scene using existing scene data""" + scene_id = scene.get('id') + if scene_id is None: + log.error("Scene missing 'id' field") + return + + current_tag_ids = [tag['id'] for tag in scene.get('tags', [])] + if current_tag_ids: + stash.update_scenes({ + "ids": [scene_id], + "tag_ids": {"ids": current_tag_ids, "mode": "REMOVE"} + }) + log.info(f"Cleared {len(current_tag_ids)} tags from scene {scene_id}") + +def clear_all_markers_from_video(video_id: int) -> None: + """Clear all markers from a video scene""" + markers = get_scene_markers(video_id) + if markers: + delete_markers(markers) + log.info(f"Cleared all {len(markers)} markers from scene {video_id}") + +def remove_vlm_tags_from_video( + video_id: int, + remove_tagme: bool = True, + remove_errored: bool = True +) -> None: + """Remove all VLM tags from a video scene""" + vlm_tags = get_vlm_tags() + if remove_tagme: + vlm_tags.append(vlm_tagme_tag_id) + if remove_errored: + vlm_tags.append(vlm_errored_tag_id) + stash.update_scenes({ + "ids": [video_id], + "tag_ids": {"ids": vlm_tags, "mode": "REMOVE"} + }) + +def get_tagme_scenes() -> List[Dict[str, Any]]: + """Get scenes tagged with VLM_TagMe""" + return stash.find_scenes( + f={"tags": {"value": vlm_tagme_tag_id, "modifier": "INCLUDES"}}, + fragment="id tags {id} files {path duration fingerprint(type: \"phash\")}" + ) + +def add_error_scene(scene_id: int) -> None: + """Add error tag to a scene""" + stash.update_scenes({ + "ids": [scene_id], + "tag_ids": {"ids": [vlm_errored_tag_id], "mode": "ADD"} + }) + +def remove_tagme_tag_from_scene(scene_id: int) -> None: + """Remove VLM_TagMe tag from a scene""" + stash.update_scenes({ + "ids": [scene_id], + "tag_ids": {"ids": [vlm_tagme_tag_id], "mode": "REMOVE"} + }) + +# ----------------- Marker Management Methods ----------------- + +def add_markers_to_video_from_dict( + video_id: int, + tag_timespans_dict: Dict[str, Dict[str, List[Any]]] +) -> None: + """Add markers to video from timespan dictionary""" + for _, tag_timespan_dict in tag_timespans_dict.items(): + for tag_name, time_frames in tag_timespan_dict.items(): + tag_id = get_tag_id(tag_name, create=True) + if tag_id: + add_markers_to_video(video_id, tag_id, tag_name, time_frames) + +def get_incorrect_markers() -> List[Dict[str, Any]]: + """Get markers tagged with VLM_Incorrect""" + if end_seconds_support: + return stash.find_scene_markers( + {"tags": {"value": vlm_incorrect_tag_id, "modifier": "INCLUDES"}}, + fragment="id scene {id files{path}} primary_tag {id, name} seconds end_seconds" + ) + else: + return stash.find_scene_markers( + {"tags": {"value": vlm_incorrect_tag_id, "modifier": "INCLUDES"}}, + 
fragment="id scene {id files{path}} primary_tag {id, name} seconds" + ) + +def add_markers_to_video( + video_id: int, + tag_id: int, + tag_name: str, + time_frames: List[Any] +) -> None: + """Add markers to video for specific time frames""" + for time_frame in time_frames: + if end_seconds_support: + stash.create_scene_marker({ + "scene_id": video_id, + "primary_tag_id": tag_id, + "tag_ids": [tag_id], + "seconds": time_frame.start, + "end_seconds": time_frame.end, + "title": tag_name + }) + else: + stash.create_scene_marker({ + "scene_id": video_id, + "primary_tag_id": tag_id, + "tag_ids": [tag_id], + "seconds": time_frame.start, + "title": tag_name + }) + +def get_scene_markers(video_id: int) -> List[Dict[str, Any]]: + """Get all markers for a scene""" + return stash.get_scene_markers(video_id) + +def write_scene_marker_to_file( + marker: Dict[str, Any], + scene_file: str, + output_folder: str +) -> None: + """Write scene marker data to file for analysis""" + try: + marker_id = marker['id'] + scene_id = marker['scene']['id'] + tag_name = marker['primary_tag']['name'] + + # Create output filename + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"marker_{marker_id}_scene_{scene_id}_{tag_name}_{timestamp}.json" + output_path = os.path.join(output_folder, filename) + + # Prepare marker data + marker_data = { + "marker_id": marker_id, + "scene_id": scene_id, + "tag_name": tag_name, + "seconds": marker.get("seconds"), + "end_seconds": marker.get("end_seconds"), + "scene_file": scene_file, + "timestamp": timestamp + } + + # Write to file + with open(output_path, 'w') as f: + json.dump(marker_data, f, indent=2) + + except Exception as e: + log.error(f"Failed to write marker data: {e}") + +def delete_markers(markers: List[Dict[str, Any]]) -> None: + """Delete markers from StashApp""" + for marker in markers: + try: + stash.destroy_scene_marker(marker['id']) + except Exception as e: + log.error(f"Failed to delete marker {marker['id']}: {e}") + +def get_scene_markers_by_tag( + video_id: int, + error_if_no_end_seconds: bool = True +) -> List[Dict[str, Any]]: + """Get scene markers by tag with end_seconds support check""" + if end_seconds_support: + return stash.get_scene_markers(video_id) + else: + if error_if_no_end_seconds: + log.error("End seconds not supported in this StashApp version") + raise RuntimeError("End seconds not supported") + return stash.get_scene_markers(video_id) + +def remove_incorrect_tag_from_markers(markers: List[Dict[str, Any]]) -> None: + """Remove VLM_Incorrect tag from markers""" + marker_ids = [marker['id'] for marker in markers] + for marker_id in marker_ids: + try: + stash.update_scene_marker({ + "id": marker_id, + "tag_ids": {"ids": [vlm_incorrect_tag_id], "mode": "REMOVE"} + }) + except Exception as e: + log.error(f"Failed to remove incorrect tag from marker {marker_id}: {e}") + +def remove_vlm_markers_from_video(video_id: int) -> None: + """Remove all VLM markers from a video""" + markers = get_scene_markers(video_id) + vlm_tag_ids = get_vlm_tags() + + for marker in markers: + if marker['primary_tag']['id'] in vlm_tag_ids: + try: + stash.destroy_scene_marker(marker['id']) + except Exception as e: + log.error(f"Failed to delete VLM marker {marker['id']}: {e}") \ No newline at end of file diff --git a/plugins/AHavenVLMConnector/haven_vlm_config.py b/plugins/AHavenVLMConnector/haven_vlm_config.py new file mode 100644 index 00000000..1e5bc69a --- /dev/null +++ b/plugins/AHavenVLMConnector/haven_vlm_config.py @@ -0,0 +1,445 @@ +""" +Configuration for A 
diff --git a/plugins/AHavenVLMConnector/haven_vlm_config.py b/plugins/AHavenVLMConnector/haven_vlm_config.py
new file mode 100644
index 00000000..1e5bc69a
--- /dev/null
+++ b/plugins/AHavenVLMConnector/haven_vlm_config.py
@@ -0,0 +1,445 @@
"""
Configuration for A Haven VLM Connector
A StashApp plugin for Vision-Language Model based content tagging
"""

from typing import Dict, List, Optional
from dataclasses import dataclass
import os
import yaml

# ----------------- Core Settings -----------------

# Tags recognized by the VLM model; category_config below is generated from
# this list so the two always stay in sync.
VLM_TAG_LIST = [
    "Anal Fucking", "Ass Licking", "Ass Penetration", "Ball Licking/Sucking", "Blowjob", "Cum on Person",
    "Cum Swapping", "Cumshot", "Deepthroat", "Double Penetration", "Fingering", "Fisting", "Footjob",
    "Gangbang", "Gloryhole", "Grabbing Ass", "Grabbing Boobs", "Grabbing Hair/Head", "Handjob", "Kissing",
    "Licking Penis", "Masturbation", "Pissing", "Pussy Licking (Clearly Visible)", "Pussy Licking",
    "Pussy Rubbing", "Sucking Fingers", "Sucking Toy/Dildo", "Wet (Genitals)", "Titjob", "Tribbing/Scissoring",
    "Undressing", "Vaginal Penetration", "Vaginal Fucking", "Vibrating"
]

def _default_category_entry(tag: str) -> Dict:
    """Default marker settings shared by every action-detection tag"""
    return {
        "RenamedTag": tag,
        "MinMarkerDuration": "1s",
        "MaxGap": "30s",
        "RequiredDuration": "1s",
        "TagThreshold": 0.5,
    }

# VLM Engine Configuration
VLM_ENGINE_CONFIG = {
    "active_ai_models": ["vlm_multiplexer_model"],
    "pipelines": {
        "video_pipeline_dynamic": {
            "inputs": [
                "video_path",
                "return_timestamps",
                "time_interval",
                "threshold",
                "return_confidence",
                "vr_video",
                "existing_video_data",
                "skipped_categories",
            ],
            "output": "results",
            "short_name": "dynamic_video",
            "version": 1.0,
            "models": [
                {
                    "name": "dynamic_video_ai",
                    "inputs": [
                        "video_path", "return_timestamps", "time_interval",
                        "threshold", "return_confidence", "vr_video",
                        "existing_video_data", "skipped_categories"
                    ],
                    "outputs": "results",
                },
            ],
        }
    },
    "models": {
        "binary_search_processor_dynamic": {
            "type": "binary_search_processor",
            "model_file_name": "binary_search_processor_dynamic"
        },
        "vlm_multiplexer_model": {
            "type": "vlm_model",
            "model_file_name": "vlm_multiplexer_model",
            "model_category": "actiondetection",
            "model_id": "zai-org/glm-4.6v-flash",
            "model_identifier": 93848,
            "model_version": "1.0",
            "use_multiplexer": True,
            "max_concurrent_requests": 13,
            "instance_count": 10,
            "max_batch_size": 4,
            "multiplexer_endpoints": [
                {
                    "base_url": "http://localhost:1234/v1",
                    "api_key": "",
                    "name": "lm-studio-primary",
                    "weight": 9,
                    "is_fallback": False,
                    "max_concurrent": 10
                },
                {
                    "base_url": "https://cloudagnostic.com:443/v1",
                    "api_key": "",
                    "name": "cloud-fallback",
                    "weight": 1,
                    "is_fallback": True,
                    "max_concurrent": 2
                }
            ],
            "tag_list": VLM_TAG_LIST
        },
        "result_coalescer": {
            "type": "python",
            "model_file_name": "result_coalescer"
        },
        "result_finisher": {
            "type": "python",
            "model_file_name": "result_finisher"
        },
        "batch_awaiter": {
            "type": "python",
            "model_file_name": "batch_awaiter"
        },
        "video_result_postprocessor": {
            "type": "python",
            "model_file_name": "video_result_postprocessor"
        },
    },
    # Every tag currently shares the same default marker settings; "69" is
    # tracked for markers even though it is not part of the model's tag list.
    "category_config": {
        "actiondetection": {
            tag: _default_category_entry(tag) for tag in ["69"] + VLM_TAG_LIST
        }
    }
}

# ----------------- Processing Settings -----------------

# Video processing settings
VIDEO_FRAME_INTERVAL = 80 # Process every 80 seconds
VIDEO_THRESHOLD = 0.3
VIDEO_CONFIDENCE_RETURN = True

# Concurrency settings
CONCURRENT_TASK_LIMIT = 20 # Increased for better parallel video processing
SERVER_TIMEOUT = 3700

# ----------------- Tag Configuration -----------------

# Tag names for StashApp integration
VLM_BASE_TAG_NAME = "VLM"
VLM_TAGME_TAG_NAME = "VLM_TagMe"
VLM_UPDATEME_TAG_NAME = "VLM_UpdateMe"
VLM_TAGGED_TAG_NAME = "VLM_Tagged"
VLM_ERRORED_TAG_NAME = "VLM_Errored"
VLM_INCORRECT_TAG_NAME = "VLM_Incorrect"

# ----------------- File System Settings -----------------

# Directory paths
OUTPUT_DATA_DIR = "./output_data"

# File management
DELETE_INCORRECT_MARKERS = True
CREATE_MARKERS = True

# Path mutations for different environments
PATH_MUTATION = {}

# ----------------- Configuration Loading -----------------

@dataclass
class VLMConnectorConfig:
    """Configuration class for the VLM Connector"""
    vlm_engine_config: Dict
    video_frame_interval: float
    video_threshold: float
    video_confidence_return: bool
    concurrent_task_limit: int
    server_timeout: int
    vlm_base_tag_name: str
    vlm_tagme_tag_name: str
    vlm_updateme_tag_name: str
    vlm_tagged_tag_name: str
    vlm_errored_tag_name: str
    vlm_incorrect_tag_name: str
    output_data_dir: str
    delete_incorrect_markers: bool
    create_markers: bool
    path_mutation: Dict

def load_config_from_yaml(config_path: Optional[str] = None) -> VLMConnectorConfig:
    """Load configuration from YAML file or use defaults"""
    if config_path and os.path.exists(config_path):
        with open(config_path, 'r') as f:
            yaml_config = yaml.safe_load(f)
        return VLMConnectorConfig(**yaml_config)

    # Return default configuration
    return VLMConnectorConfig(
        vlm_engine_config=VLM_ENGINE_CONFIG,
        video_frame_interval=VIDEO_FRAME_INTERVAL,
        video_threshold=VIDEO_THRESHOLD,
        video_confidence_return=VIDEO_CONFIDENCE_RETURN,
        concurrent_task_limit=CONCURRENT_TASK_LIMIT,
        server_timeout=SERVER_TIMEOUT,
        vlm_base_tag_name=VLM_BASE_TAG_NAME,
        vlm_tagme_tag_name=VLM_TAGME_TAG_NAME,
vlm_updateme_tag_name=VLM_UPDATEME_TAG_NAME, + vlm_tagged_tag_name=VLM_TAGGED_TAG_NAME, + vlm_errored_tag_name=VLM_ERRORED_TAG_NAME, + vlm_incorrect_tag_name=VLM_INCORRECT_TAG_NAME, + output_data_dir=OUTPUT_DATA_DIR, + delete_incorrect_markers=DELETE_INCORRECT_MARKERS, + create_markers=CREATE_MARKERS, + path_mutation=PATH_MUTATION + ) + +# Global configuration instance +config = load_config_from_yaml() diff --git a/plugins/AHavenVLMConnector/haven_vlm_connector.py b/plugins/AHavenVLMConnector/haven_vlm_connector.py new file mode 100644 index 00000000..e6655d60 --- /dev/null +++ b/plugins/AHavenVLMConnector/haven_vlm_connector.py @@ -0,0 +1,444 @@ +""" +A Haven VLM Connector +A StashApp plugin for Vision-Language Model based content tagging +""" + +import os +import sys +import json +import shutil +import traceback +import asyncio +import logging +import time +from typing import Dict, Any, List, Optional +from datetime import datetime + +# Import and install sys.exit tracking FIRST (before any other imports that might call sys.exit) +try: + from exit_tracker import install_exit_tracker + import stashapi.log as log + install_exit_tracker(log) +except ImportError as e: + print(f"Warning: exit_tracker not available: {e}") + print("sys.exit tracking will not be available") + +# ----------------- Setup and Dependencies ----------------- + +# Use PythonDepManager for dependency management +try: + from PythonDepManager import ensure_import + + # Install and ensure all required dependencies with specific versions + ensure_import( + "stashapi:stashapp-tools==0.2.58", + "aiohttp==3.12.13", + "pydantic==2.11.7", + "vlm-engine==0.9.1", + "pyyaml==6.0.2" + ) + + # Import the dependencies after ensuring they're available + import stashapi.log as log + from stashapi.stashapp import StashInterface + import aiohttp + import pydantic + import yaml + +except ImportError as e: + print(f"Failed to import PythonDepManager or required dependencies: {e}") + print("Please ensure PythonDepManager is installed and available.") + sys.exit(1) +except Exception as e: + print(f"Error during dependency management: {e}") + print(f"Stack trace: {traceback.format_exc()}") + sys.exit(1) + +# Import local modules +try: + import haven_vlm_config as config +except ModuleNotFoundError: + log.error("Please provide a haven_vlm_config.py file with the required variables.") + raise Exception("Please provide a haven_vlm_config.py file with the required variables.") + +import haven_media_handler as media_handler +import haven_vlm_engine as vlm_engine +from haven_vlm_engine import TimeFrame + +log.debug("Python instance is running at: " + sys.executable) + +# ----------------- Global Variables ----------------- + +semaphore: Optional[asyncio.Semaphore] = None +progress: float = 0.0 +increment: float = 0.0 +completed_tasks: int = 0 +total_tasks: int = 0 +video_progress: Dict[str, float] = {} + +# ----------------- Main Execution ----------------- + +async def main() -> None: + """Main entry point for the plugin""" + global semaphore + + # Semaphore initialization logging for hypothesis A + log.debug(f"[DEBUG_HYPOTHESIS_A] Initializing semaphore with limit {config.config.concurrent_task_limit}") + + semaphore = asyncio.Semaphore(config.config.concurrent_task_limit) + + # Post-semaphore creation logging + log.debug(f"[DEBUG_HYPOTHESIS_A] Semaphore created successfully (limit: {config.config.concurrent_task_limit})") + + json_input = read_json_input() + output = {} + await run(json_input, output) + out = json.dumps(output) + print(out + 
"\n") + +def read_json_input() -> Dict[str, Any]: + """Read JSON input from stdin""" + json_input = sys.stdin.read() + return json.loads(json_input) + +async def run(json_input: Dict[str, Any], output: Dict[str, Any]) -> None: + """Main execution logic""" + plugin_args = None + try: + log.debug(json_input["server_connection"]) + os.chdir(json_input["server_connection"]["PluginDir"]) + media_handler.initialize(json_input["server_connection"]) + except Exception as e: + log.error(f"Failed to initialize media handler: {e}") + raise + + try: + plugin_args = json_input['args']["mode"] + except KeyError: + pass + + if plugin_args == "tag_videos": + await tag_videos() + output["output"] = "ok" + return + elif plugin_args == "find_marker_settings": + await find_marker_settings() + output["output"] = "ok" + return + elif plugin_args == "collect_incorrect_markers": + collect_incorrect_markers_and_images() + output["output"] = "ok" + return + + output["output"] = "ok" + return + +# ----------------- High Level Processing Functions ----------------- + +async def tag_videos() -> None: + """Tag videos with VLM analysis using improved async orchestration""" + global completed_tasks, total_tasks + + scenes = media_handler.get_tagme_scenes() + if not scenes: + log.info("No videos to tag. Have you tagged any scenes with the VLM_TagMe tag to get processed?") + return + + total_tasks = len(scenes) + completed_tasks = 0 + + video_progress.clear() + for scene in scenes: + video_progress[scene.get('id', 'unknown')] = 0.0 + log.progress(0.0) + + log.info(f"🚀 Starting video processing for {total_tasks} scenes with semaphore limit of {config.config.concurrent_task_limit}") + + # Create tasks with proper indexing for debugging + tasks = [] + for i, scene in enumerate(scenes): + # Pre-task creation logging for hypothesis A (semaphore deadlock) and E (signal termination) + scene_id = scene.get('id') + log.debug(f"[DEBUG_HYPOTHESIS_A] Creating task {i+1}/{total_tasks} for scene {scene_id}, semaphore limit: {config.config.concurrent_task_limit}") + + task = asyncio.create_task(__tag_video_with_timing(scene, i)) + tasks.append(task) + + # Use asyncio.as_completed to process results as they finish (proves concurrency) + completed_task_futures = asyncio.as_completed(tasks) + + batch_start_time = asyncio.get_event_loop().time() + + for completed_task in completed_task_futures: + try: + await completed_task + completed_tasks += 1 + + except Exception as e: + completed_tasks += 1 + # Exception logging for hypothesis E (signal termination) + error_type = type(e).__name__ + log.debug(f"[DEBUG_HYPOTHESIS_E] Task failed with exception: {error_type}: {str(e)} (Task {completed_tasks}/{total_tasks})") + + log.error(f"❌ Task failed: {e}") + + total_time = asyncio.get_event_loop().time() - batch_start_time + + log.info(f"🎉 All {total_tasks} videos completed in {total_time:.2f}s (avg: {total_time/total_tasks:.2f}s/video)") + log.progress(1.0) + +async def find_marker_settings() -> None: + """Find optimal marker settings based on a single tagged video""" + scenes = media_handler.get_tagme_scenes() + if len(scenes) != 1: + log.error("Please tag exactly one scene with the VLM_TagMe tag to get processed.") + return + scene = scenes[0] + await __find_marker_settings(scene) + +def collect_incorrect_markers_and_images() -> None: + """Collect data from incorrectly tagged markers and images""" + incorrect_images = media_handler.get_incorrect_images() + image_paths, image_ids, temp_files = media_handler.get_image_paths_and_ids(incorrect_images) 
+ incorrect_markers = media_handler.get_incorrect_markers() + + if not (len(incorrect_images) > 0 or len(incorrect_markers) > 0): + log.info("No incorrect images or markers to collect.") + return + + current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + + try: + # Process images + image_folder = os.path.join(config.config.output_data_dir, "images") + os.makedirs(image_folder, exist_ok=True) + for image_path in image_paths: + try: + shutil.copy(image_path, image_folder) + except Exception as e: + log.error(f"Failed to copy image {image_path} to {image_folder}: {e}") + except Exception as e: + log.error(f"Failed to process images: {e}") + raise e + finally: + # Clean up temp files + for temp_file in temp_files: + try: + if os.path.isdir(temp_file): + shutil.rmtree(temp_file) + else: + os.remove(temp_file) + except Exception as e: + log.debug(f"Failed to remove temp file {temp_file}: {e}") + + # Process markers + scene_folder = os.path.join(config.config.output_data_dir, "scenes") + os.makedirs(scene_folder, exist_ok=True) + tag_folders = {} + + for marker in incorrect_markers: + scene_path = marker['scene']['files'][0]['path'] + if not scene_path: + log.error(f"Marker {marker['id']} has no scene path") + continue + try: + tag_name = marker['primary_tag']['name'] + if tag_name not in tag_folders: + tag_folders[tag_name] = os.path.join(scene_folder, tag_name) + os.makedirs(tag_folders[tag_name], exist_ok=True) + media_handler.write_scene_marker_to_file(marker, scene_path, tag_folders[tag_name]) + except Exception as e: + log.error(f"Failed to collect scene: {e}") + + # Remove incorrect tags from images + image_ids = [image['id'] for image in incorrect_images] + media_handler.remove_incorrect_tag_from_images(image_ids) + +# ----------------- Low Level Processing Functions ----------------- + +async def __tag_video_with_timing(scene: Dict[str, Any], scene_index: int) -> None: + """Tag a single video scene with timing diagnostics""" + start_time = asyncio.get_event_loop().time() + scene_id = scene.get('id', 'unknown') + + log.info(f"🎬 Starting video {scene_index + 1}: Scene {scene_id}") + + try: + await __tag_video(scene) + end_time = asyncio.get_event_loop().time() + duration = end_time - start_time + log.info(f"✅ Completed video {scene_index + 1} (Scene {scene_id}) in {duration:.2f}s") + + except Exception as e: + end_time = asyncio.get_event_loop().time() + duration = end_time - start_time + log.error(f"❌ Failed video {scene_index + 1} (Scene {scene_id}) after {duration:.2f}s: {e}") + raise + +async def __tag_video(scene: Dict[str, Any]) -> None: + """Tag a single video scene with semaphore timing instrumentation""" + scene_id = scene.get('id') + + # Pre-semaphore acquisition logging for hypothesis A (semaphore deadlock) + task_start_time = asyncio.get_event_loop().time() + acquisition_start_time = task_start_time + log.debug(f"[DEBUG_HYPOTHESIS_A] Task starting for scene {scene_id} at {task_start_time:.3f}s") + + async with semaphore: + try: + # Semaphore acquisition successful logging + acquisition_end_time = asyncio.get_event_loop().time() + acquisition_time = acquisition_end_time - acquisition_start_time + log.debug(f"[DEBUG_HYPOTHESIS_A] Semaphore acquired for scene {scene_id} after {acquisition_time:.3f}s") + + if scene_id is None: + log.error("Scene missing 'id' field") + return + + files = scene.get('files', []) + if not files: + log.error(f"Scene {scene_id} has no files") + return + + scene_file = files[0].get('path') + if scene_file is None: + log.error(f"Scene {scene_id} 
file has no path") + return + + # Check if scene is VR + is_vr = media_handler.is_vr_scene(scene.get('tags', [])) + + def progress_cb(p: int) -> None: + global video_progress, total_tasks + video_progress[scene_id] = p / 100.0 + total_prog = sum(video_progress.values()) / total_tasks + log.progress(total_prog) + + # Process video through VLM Engine with HTTP timing for hypothesis B + processing_start_time = asyncio.get_event_loop().time() + + # HTTP request lifecycle tracking start + log.debug(f"[DEBUG_HYPOTHESIS_B] Starting VLM processing for scene {scene_id}: {scene_file}") + + video_result = await vlm_engine.process_video_async( + scene_file, + vr_video=is_vr, + frame_interval=config.config.video_frame_interval, + threshold=config.config.video_threshold, + return_confidence=config.config.video_confidence_return, + progress_callback=progress_cb + ) + + # Extract detected tags + detected_tags = set() + for category_tags in video_result.video_tags.values(): + detected_tags.update(category_tags) + + # Post-VLM processing logging + processing_end_time = asyncio.get_event_loop().time() + processing_duration = processing_end_time - processing_start_time + log.debug(f"[DEBUG_HYPOTHESIS_B] VLM processing completed for scene {scene_id} in {processing_duration:.2f}s ({len(detected_tags)} detected tags)") + + if detected_tags: + # Clear all existing tags and markers before adding new ones + media_handler.clear_all_tags_from_video(scene) + media_handler.clear_all_markers_from_video(scene_id) + + # Add tags to scene + tag_ids = media_handler.get_tag_ids(list(detected_tags), create=True) + media_handler.add_tags_to_video(scene_id, tag_ids) + log.info(f"Added tags {list(detected_tags)} to scene {scene_id}") + + # Add markers if enabled + if config.config.create_markers: + media_handler.add_markers_to_video_from_dict(scene_id, video_result.tag_timespans) + log.info(f"Added markers to scene {scene_id}") + + # Remove VLM_TagMe tag from processed scene + media_handler.remove_tagme_tag_from_scene(scene_id) + + # Task completion logging + task_end_time = asyncio.get_event_loop().time() + total_task_time = task_end_time - task_start_time + log.debug(f"[DEBUG_HYPOTHESIS_A] Task completed for scene {scene_id} in {total_task_time:.2f}s") + + except Exception as e: + # Exception handling with detailed logging for hypothesis E + exception_time = asyncio.get_event_loop().time() + error_type = type(e).__name__ + log.debug(f"[DEBUG_HYPOTHESIS_E] Task exception for scene {scene_id}: {error_type}: {str(e)} at {exception_time:.3f}s") + + scene_id = scene.get('id', 'unknown') + log.error(f"Error processing video scene {scene_id}: {e}") + # Add error tag to failed scene if we have a valid ID + if scene_id != 'unknown': + media_handler.add_error_scene(scene_id) + +async def __find_marker_settings(scene: Dict[str, Any]) -> None: + """Find optimal marker settings for a scene""" + try: + scene_id = scene.get('id') + if scene_id is None: + log.error("Scene missing 'id' field") + return + + files = scene.get('files', []) + if not files: + log.error(f"Scene {scene_id} has no files") + return + + scene_file = files[0].get('path') + if scene_file is None: + log.error(f"Scene {scene_id} file has no path") + return + + # Get existing markers for the scene + existing_markers = media_handler.get_scene_markers(scene_id) + + # Convert markers to desired timespan format + desired_timespan_data = {} + for marker in existing_markers: + tag_name = marker['primary_tag']['name'] + desired_timespan_data[tag_name] = TimeFrame( + 
start=marker['seconds'], + end=marker.get('end_seconds', marker['seconds'] + 1), + total_confidence=1.0 + ) + + # Find optimal settings + optimal_settings = await vlm_engine.find_optimal_marker_settings_async( + existing_json={}, # No existing JSON data + desired_timespan_data=desired_timespan_data + ) + + # Output results + log.info(f"Optimal marker settings found for scene {scene_id}:") + log.info(json.dumps(optimal_settings, indent=2)) + + except Exception as e: + scene_id = scene.get('id', 'unknown') + log.error(f"Error finding marker settings for scene {scene_id}: {e}") + +# ----------------- Cleanup ----------------- + +async def cleanup() -> None: + """Cleanup resources""" + if vlm_engine.vlm_engine: + await vlm_engine.vlm_engine.shutdown() + +# Run main function if script is executed directly +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + log.info("Plugin interrupted by user") + sys.exit(0) + except SystemExit as e: + # Re-raise system exit with the exit code + log.debug(f"[DEBUG_EXIT_CODE] Caught SystemExit with code: {e.code}") + raise + except Exception as e: + log.error(f"Plugin failed: {e}") + sys.exit(1) + finally: + asyncio.run(cleanup()) diff --git a/plugins/AHavenVLMConnector/haven_vlm_engine.py b/plugins/AHavenVLMConnector/haven_vlm_engine.py new file mode 100644 index 00000000..7c604655 --- /dev/null +++ b/plugins/AHavenVLMConnector/haven_vlm_engine.py @@ -0,0 +1,299 @@ +""" +Haven VLM Engine Integration Module +Provides integration with the Haven VLM Engine for video and image processing +""" + +import asyncio +import logging +from typing import Any, Dict, List, Optional, Set, Union, Callable +from dataclasses import dataclass +from datetime import datetime +import json + +# Use PythonDepManager for dependency management +from vlm_engine import VLMEngine +from vlm_engine.config_models import ( + EngineConfig, + PipelineConfig, + ModelConfig, + PipelineModelConfig +) + +import haven_vlm_config as config + +# Configure logging +logging.basicConfig(level=logging.CRITICAL) +logger = logging.getLogger(__name__) + +@dataclass +class TimeFrame: + """Represents a time frame with start and end times""" + start: float + end: float + total_confidence: Optional[float] = None + + def to_json(self) -> str: + """Convert to JSON string""" + return json.dumps({ + "start": self.start, + "end": self.end, + "total_confidence": self.total_confidence + }) + + def __str__(self) -> str: + return f"TimeFrame(start={self.start}, end={self.end}, confidence={self.total_confidence})" + +@dataclass +class VideoTagInfo: + """Represents video tagging information""" + video_duration: float + video_tags: Dict[str, Set[str]] + tag_totals: Dict[str, Dict[str, float]] + tag_timespans: Dict[str, Dict[str, List[TimeFrame]]] + + @classmethod + def from_json(cls, json_data: Dict[str, Any]) -> 'VideoTagInfo': + """Create VideoTagInfo from JSON data""" + logger.debug(f"Creating VideoTagInfo from JSON: {json_data}") + + # Convert tag_timespans to TimeFrame objects + tag_timespans = {} + for category, tags in json_data.get("tag_timespans", {}).items(): + tag_timespans[category] = {} + for tag_name, timeframes in tags.items(): + tag_timespans[category][tag_name] = [ + TimeFrame( + start=tf["start"], + end=tf["end"], + total_confidence=tf.get("total_confidence") + ) for tf in timeframes + ] + + return cls( + video_duration=json_data.get("video_duration", 0.0), + video_tags=json_data.get("video_tags", {}), + tag_totals=json_data.get("tag_totals", {}), + 
tag_timespans=tag_timespans + ) + + def __str__(self) -> str: + return f"VideoTagInfo(duration={self.video_duration}, tags={len(self.video_tags)}, timespans={len(self.tag_timespans)})" + +class HavenVLMEngine: + """Main VLM Engine integration class""" + + def __init__(self): + self.engine: Optional[VLMEngine] = None + self.engine_config: Optional[EngineConfig] = None + self._initialized = False + + async def initialize(self) -> None: + """Initialize the VLM Engine with configuration""" + if self._initialized: + return + + try: + logger.info("Initializing Haven VLM Engine...") + + # Convert config dict to EngineConfig objects + self.engine_config = self._create_engine_config() + + # Create and initialize the engine + self.engine = VLMEngine(config=self.engine_config) + await self.engine.initialize() + + self._initialized = True + logger.info("Haven VLM Engine initialized successfully") + + except Exception as e: + logger.error(f"Failed to initialize VLM Engine: {e}") + raise + + def _create_engine_config(self) -> EngineConfig: + """Create EngineConfig from the configuration""" + vlm_config = config.config.vlm_engine_config + + # Create pipeline configs + pipelines = {} + for pipeline_name, pipeline_data in vlm_config["pipelines"].items(): + models = [ + PipelineModelConfig( + name=model["name"], + inputs=model["inputs"], + outputs=model["outputs"] + ) for model in pipeline_data["models"] + ] + + pipelines[pipeline_name] = PipelineConfig( + inputs=pipeline_data["inputs"], + output=pipeline_data["output"], + short_name=pipeline_data["short_name"], + version=pipeline_data["version"], + models=models + ) + + # Create model configs with new architectural changes + models = {} + for model_name, model_data in vlm_config["models"].items(): + if model_data["type"] == "vlm_model": + # Process multiplexer_endpoints and validate max_concurrent + multiplexer_endpoints = [] + for endpoint in model_data.get("multiplexer_endpoints", []): + # Validate that max_concurrent is present + if "max_concurrent" not in endpoint: + raise ValueError(f"Endpoint '{endpoint.get('name', 'unnamed')}' is missing required 'max_concurrent' parameter") + + multiplexer_endpoints.append({ + "base_url": endpoint["base_url"], + "api_key": endpoint.get("api_key", ""), + "name": endpoint["name"], + "weight": endpoint.get("weight", 5), + "is_fallback": endpoint.get("is_fallback", False), + "max_concurrent": endpoint["max_concurrent"] + }) + + models[model_name] = ModelConfig( + type=model_data["type"], + model_file_name=model_data["model_file_name"], + model_category=model_data["model_category"], + model_id=model_data["model_id"], + model_identifier=model_data["model_identifier"], + model_version=model_data["model_version"], + use_multiplexer=model_data.get("use_multiplexer", False), + max_concurrent_requests=model_data.get("max_concurrent_requests", 10), + instance_count=model_data.get("instance_count",1), + max_batch_size=model_data.get("max_batch_size",1), + multiplexer_endpoints=multiplexer_endpoints, + tag_list=model_data.get("tag_list", []) + ) + else: + models[model_name] = ModelConfig( + type=model_data["type"], + model_file_name=model_data["model_file_name"] + ) + + return EngineConfig( + active_ai_models=vlm_config["active_ai_models"], + pipelines=pipelines, + models=models, + category_config=vlm_config["category_config"] + ) + + async def process_video( + self, + video_path: str, + vr_video: bool = False, + frame_interval: Optional[float] = None, + threshold: Optional[float] = None, + return_confidence: Optional[bool] = 
None,
        existing_json: Optional[Dict[str, Any]] = None,
        progress_callback: Optional[Callable[[int], None]] = None
    ) -> VideoTagInfo:
        """Process a video using the VLM Engine"""
        if not self._initialized:
            await self.initialize()

        try:
            logger.info(f"Processing video: {video_path}")

            # Use config defaults only when a value was not provided; an explicit
            # 0 or 0.0 is a valid setting and must not fall back to the default
            if frame_interval is None:
                frame_interval = config.config.video_frame_interval
            if threshold is None:
                threshold = config.config.video_threshold
            if return_confidence is None:
                return_confidence = config.config.video_confidence_return

            # Process video through the engine. Note: vr_video, threshold, and
            # return_confidence are resolved above but not currently forwarded;
            # the engine call only consumes frame_interval and the callback.
            results = await self.engine.process_video(
                video_path,
                frame_interval=frame_interval,
                progress_callback=progress_callback
            )

            logger.info(f"Video processing completed for: {video_path}")
            logger.debug(f"Raw results structure: {type(results)}")

            # Extract video_tag_info from the nested structure
            if isinstance(results, dict) and 'video_tag_info' in results:
                video_tag_data = results['video_tag_info']
                logger.debug(f"Using video_tag_info from results: {video_tag_data.keys()}")
            else:
                # Fallback: assume results is already in the correct format
                video_tag_data = results
                logger.debug(f"Using results directly: {video_tag_data.keys() if isinstance(video_tag_data, dict) else type(video_tag_data)}")

            return VideoTagInfo.from_json(video_tag_data)

        except Exception as e:
            logger.error(f"Error processing video {video_path}: {e}")
            raise

    async def find_optimal_marker_settings(
        self,
        existing_json: Dict[str, Any],
        desired_timespan_data: Dict[str, TimeFrame]
    ) -> Dict[str, Any]:
        """Find optimal marker settings based on existing data"""
        if not self._initialized:
            await self.initialize()

        try:
            logger.info("Finding optimal marker settings...")

            # Convert TimeFrame objects to dict format
            desired_data = {}
            for key, timeframe in desired_timespan_data.items():
                desired_data[key] = {
                    "start": timeframe.start,
                    "end": timeframe.end,
                    "total_confidence": timeframe.total_confidence
                }

            # Call the engine's optimization method
            results = await self.engine.optimize_timeframe_settings(
                existing_json_data=existing_json,
                desired_timespan_data=desired_data
            )

            logger.info("Optimal marker settings found")
            return results

        except Exception as e:
            logger.error(f"Error finding optimal marker settings: {e}")
            raise

    async def shutdown(self) -> None:
        """Shutdown the VLM Engine"""
        if self.engine and self._initialized:
            try:
                # VLMEngine doesn't have a shutdown method, just perform basic cleanup
                logger.info("VLM Engine cleanup completed")
                self._initialized = False

            except Exception as e:
                logger.error(f"Error during VLM Engine cleanup: {e}")
                self._initialized = False

# Global VLM Engine instance
vlm_engine = HavenVLMEngine()

# Convenience functions for backward compatibility
async def process_video_async(
    video_path: str,
    vr_video: bool = False,
    frame_interval: Optional[float] = None,
    threshold: Optional[float] = None,
    return_confidence: Optional[bool] = None,
    existing_json: Optional[Dict[str, Any]] = None,
    progress_callback: Optional[Callable[[int], None]] = None
) -> VideoTagInfo:
    """Process video asynchronously"""
    return await vlm_engine.process_video(
        video_path, vr_video, frame_interval, threshold, return_confidence, existing_json,
        progress_callback=progress_callback
    )

async def find_optimal_marker_settings_async(
    existing_json: Dict[str, Any],
    desired_timespan_data: 
Dict[str, TimeFrame] +) -> Dict[str, Any]: + """Find optimal marker settings asynchronously""" + return await vlm_engine.find_optimal_marker_settings(existing_json, desired_timespan_data) diff --git a/plugins/AHavenVLMConnector/haven_vlm_utility.py b/plugins/AHavenVLMConnector/haven_vlm_utility.py new file mode 100644 index 00000000..1a1e032f --- /dev/null +++ b/plugins/AHavenVLMConnector/haven_vlm_utility.py @@ -0,0 +1,316 @@ +""" +Haven VLM Utility Module +Utility functions for the A Haven VLM Connector plugin +""" + +import os +import json +import logging +from typing import Dict, Any, List, Optional, Union +from pathlib import Path +import yaml + +logger = logging.getLogger(__name__) + +def apply_path_mutations(path: str, mutations: Dict[str, str]) -> str: + """ + Apply path mutations for different environments + + Args: + path: Original file path + mutations: Dictionary of path mutations (e.g., {"E:": "F:", "G:": "D:"}) + + Returns: + Mutated path string + """ + if not mutations: + return path + + mutated_path = path + for old_path, new_path in mutations.items(): + if mutated_path.startswith(old_path): + mutated_path = mutated_path.replace(old_path, new_path, 1) + break + + return mutated_path + +def ensure_directory_exists(directory_path: str) -> None: + """ + Ensure a directory exists, creating it if necessary + + Args: + directory_path: Path to the directory + """ + Path(directory_path).mkdir(parents=True, exist_ok=True) + +def safe_file_operation(operation_func, *args, **kwargs) -> Optional[Any]: + """ + Safely execute a file operation with error handling + + Args: + operation_func: Function to execute + *args: Arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + Result of the operation or None if failed + """ + try: + return operation_func(*args, **kwargs) + except (OSError, IOError) as e: + logger.error(f"File operation failed: {e}") + return None + except Exception as e: + logger.error(f"Unexpected error in file operation: {e}") + return None + +def load_yaml_config(config_path: str) -> Optional[Dict[str, Any]]: + """ + Load configuration from YAML file + + Args: + config_path: Path to the YAML configuration file + + Returns: + Configuration dictionary or None if failed + """ + try: + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + logger.info(f"Configuration loaded from {config_path}") + return config + except FileNotFoundError: + logger.warning(f"Configuration file not found: {config_path}") + return None + except yaml.YAMLError as e: + logger.error(f"Error parsing YAML configuration: {e}") + return None + except Exception as e: + logger.error(f"Unexpected error loading configuration: {e}") + return None + +def save_yaml_config(config: Dict[str, Any], config_path: str) -> bool: + """ + Save configuration to YAML file + + Args: + config: Configuration dictionary + config_path: Path to save the configuration file + + Returns: + True if successful, False otherwise + """ + try: + ensure_directory_exists(os.path.dirname(config_path)) + with open(config_path, 'w', encoding='utf-8') as f: + yaml.dump(config, f, default_flow_style=False, indent=2) + logger.info(f"Configuration saved to {config_path}") + return True + except Exception as e: + logger.error(f"Error saving configuration: {e}") + return False + +def validate_file_path(file_path: str) -> bool: + """ + Validate if a file path exists and is accessible + + Args: + file_path: Path to validate + + Returns: + True if file exists and is accessible, False 
otherwise + """ + try: + return os.path.isfile(file_path) and os.access(file_path, os.R_OK) + except Exception: + return False + +def get_file_extension(file_path: str) -> str: + """ + Get the file extension from a file path + + Args: + file_path: Path to the file + + Returns: + File extension (including the dot) + """ + return Path(file_path).suffix.lower() + +def is_video_file(file_path: str) -> bool: + """ + Check if a file is a video file based on its extension + + Args: + file_path: Path to the file + + Returns: + True if it's a video file, False otherwise + """ + video_extensions = {'.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.m4v'} + return get_file_extension(file_path) in video_extensions + +def is_image_file(file_path: str) -> bool: + """ + Check if a file is an image file based on its extension + + Args: + file_path: Path to the file + + Returns: + True if it's an image file, False otherwise + """ + image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'} + return get_file_extension(file_path) in image_extensions + +def format_duration(seconds: float) -> str: + """ + Format duration in seconds to human-readable string + + Args: + seconds: Duration in seconds + + Returns: + Formatted duration string (e.g., "1h 23m 45s") + """ + if seconds < 60: + return f"{seconds:.1f}s" + elif seconds < 3600: + minutes = int(seconds // 60) + remaining_seconds = seconds % 60 + return f"{minutes}m {remaining_seconds:.1f}s" + else: + hours = int(seconds // 3600) + remaining_minutes = int((seconds % 3600) // 60) + remaining_seconds = seconds % 60 + return f"{hours}h {remaining_minutes}m {remaining_seconds:.1f}s" + +def format_file_size(bytes_size: int) -> str: + """ + Format file size in bytes to human-readable string + + Args: + bytes_size: Size in bytes + + Returns: + Formatted size string (e.g., "1.5 MB") + """ + for unit in ['B', 'KB', 'MB', 'GB', 'TB']: + if bytes_size < 1024.0: + return f"{bytes_size:.1f} {unit}" + bytes_size /= 1024.0 + return f"{bytes_size:.1f} PB" + +def sanitize_filename(filename: str) -> str: + """ + Sanitize a filename by removing or replacing invalid characters + + Args: + filename: Original filename + + Returns: + Sanitized filename + """ + # Replace invalid characters with underscores + invalid_chars = '<>:"/\\|?*' + for char in invalid_chars: + filename = filename.replace(char, '_') + + # Remove leading/trailing spaces and dots + filename = filename.strip(' .') + + # Ensure filename is not empty + if not filename: + filename = "unnamed" + + return filename + +def create_backup_file(file_path: str, backup_suffix: str = ".backup") -> Optional[str]: + """ + Create a backup of a file + + Args: + file_path: Path to the file to backup + backup_suffix: Suffix for the backup file + + Returns: + Path to the backup file or None if failed + """ + try: + if not os.path.exists(file_path): + logger.warning(f"File does not exist: {file_path}") + return None + + backup_path = file_path + backup_suffix + import shutil + shutil.copy2(file_path, backup_path) + logger.info(f"Backup created: {backup_path}") + return backup_path + except Exception as e: + logger.error(f"Failed to create backup: {e}") + return None + +def merge_dictionaries(dict1: Dict[str, Any], dict2: Dict[str, Any], overwrite: bool = True) -> Dict[str, Any]: + """ + Merge two dictionaries, with option to overwrite existing keys + + Args: + dict1: First dictionary + dict2: Second dictionary + overwrite: Whether to overwrite existing keys in dict1 + + Returns: + Merged dictionary + """ + 
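+    # NOTE: as written, overwrite=True makes the first branch below win for
+    # every key, so nested dicts from dict2 replace those in dict1 wholesale;
+    # the recursive deep-merge branch is only reachable when overwrite=False.
+    # Illustrative behavior (of the code as written, not a spec):
+    #   merge_dictionaries({"a": {"x": 1}}, {"a": {"y": 2}}, True)   -> {"a": {"y": 2}}
+    #   merge_dictionaries({"a": {"x": 1}}, {"a": {"y": 2}}, False)  -> {"a": {"x": 1, "y": 2}}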
result = dict1.copy() + + for key, value in dict2.items(): + if key not in result or overwrite: + result[key] = value + elif isinstance(result[key], dict) and isinstance(value, dict): + result[key] = merge_dictionaries(result[key], value, overwrite) + + return result + +def chunk_list(lst: List[Any], chunk_size: int) -> List[List[Any]]: + """ + Split a list into chunks of specified size + + Args: + lst: List to chunk + chunk_size: Size of each chunk + + Returns: + List of chunks + """ + return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)] + +def retry_operation(operation_func, max_retries: int = 3, delay: float = 1.0, *args, **kwargs) -> Optional[Any]: + """ + Retry an operation with exponential backoff + + Args: + operation_func: Function to retry + max_retries: Maximum number of retries + delay: Initial delay between retries + *args: Arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + Result of the operation or None if all retries failed + """ + import time + + for attempt in range(max_retries + 1): + try: + return operation_func(*args, **kwargs) + except Exception as e: + if attempt == max_retries: + logger.error(f"Operation failed after {max_retries} retries: {e}") + return None + + wait_time = delay * (2 ** attempt) + logger.warning(f"Operation failed (attempt {attempt + 1}/{max_retries + 1}), retrying in {wait_time}s: {e}") + time.sleep(wait_time) + + return None \ No newline at end of file diff --git a/plugins/AHavenVLMConnector/requirements.txt b/plugins/AHavenVLMConnector/requirements.txt new file mode 100644 index 00000000..6d704f53 --- /dev/null +++ b/plugins/AHavenVLMConnector/requirements.txt @@ -0,0 +1,8 @@ +# Core dependencies managed by PythonDepManager +# These are automatically handled by the plugin's dependency management system +# PythonDepManager will ensure the correct versions are installed + +# Development and testing dependencies +coverage>=7.0.0 +pytest>=7.0.0 +pytest-cov>=4.0.0 diff --git a/plugins/AHavenVLMConnector/run_tests.py b/plugins/AHavenVLMConnector/run_tests.py new file mode 100644 index 00000000..bc8e0500 --- /dev/null +++ b/plugins/AHavenVLMConnector/run_tests.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +""" +Test runner for A Haven VLM Connector +Runs all unit tests with coverage reporting +""" + +import sys +import os +import subprocess +import unittest +from pathlib import Path + +def install_test_dependencies(): + """Install test dependencies if not already installed""" + test_deps = [ + 'coverage', + 'pytest', + 'pytest-cov' + ] + + for dep in test_deps: + try: + __import__(dep.replace('-', '_')) + except ImportError: + print(f"Installing {dep}...") + subprocess.check_call([sys.executable, "-m", "pip", "install", dep]) + +def run_tests_with_coverage(): + """Run tests with coverage reporting""" + # Install test dependencies + install_test_dependencies() + + # Get the directory containing this script + script_dir = Path(__file__).parent + + # Discover and run tests + loader = unittest.TestLoader() + start_dir = script_dir + suite = loader.discover(start_dir, pattern='test_*.py') + + # Run tests with coverage + import coverage + + # Start coverage measurement + cov = coverage.Coverage( + source=['haven_vlm_config.py', 'haven_vlm_engine.py', 'haven_media_handler.py', + 'haven_vlm_connector.py', 'haven_vlm_utility.py'], + omit=['*/test_*.py', '*/__pycache__/*', '*/venv/*', '*/env/*'] + ) + cov.start() + + # Run the tests + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) 
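+    # NOTE: loader.discover() above already imported the test modules (and,
+    # through them, the haven_* sources), so their module-level statements ran
+    # before cov.start(); the report below reflects only code exercised by the
+    # tests themselves.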
+ + # Stop coverage measurement + cov.stop() + cov.save() + + # Generate coverage report + print("\n" + "="*60) + print("COVERAGE REPORT") + print("="*60) + cov.report() + + # Generate HTML coverage report + cov.html_report(directory='htmlcov') + print(f"\nHTML coverage report generated in: {script_dir}/htmlcov/index.html") + + return result.wasSuccessful() + +def run_specific_test(test_file): + """Run a specific test file""" + if not test_file.endswith('.py'): + test_file += '.py' + + test_path = Path(__file__).parent / test_file + + if not test_path.exists(): + print(f"Test file not found: {test_path}") + return False + + # Run the specific test + loader = unittest.TestLoader() + suite = loader.loadTestsFromName(test_file[:-3]) + + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + + return result.wasSuccessful() + +def main(): + """Main entry point""" + if len(sys.argv) > 1: + # Run specific test file + test_file = sys.argv[1] + success = run_specific_test(test_file) + else: + # Run all tests with coverage + success = run_tests_with_coverage() + + if success: + print("\n✅ All tests passed!") + sys.exit(0) + else: + print("\n❌ Some tests failed!") + sys.exit(1) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/plugins/AHavenVLMConnector/test_dependency_management.py b/plugins/AHavenVLMConnector/test_dependency_management.py new file mode 100644 index 00000000..190f5dbe --- /dev/null +++ b/plugins/AHavenVLMConnector/test_dependency_management.py @@ -0,0 +1,98 @@ +""" +Unit tests for dependency management functionality using PythonDepManager +""" + +import unittest +import sys +from unittest.mock import patch, MagicMock, mock_open +import tempfile +import os + +class TestPythonDepManagerIntegration(unittest.TestCase): + """Test cases for PythonDepManager integration""" + + def setUp(self): + """Set up test fixtures""" + # Mock PythonDepManager module + self.mock_python_dep_manager = MagicMock() + sys.modules['PythonDepManager'] = self.mock_python_dep_manager + + def tearDown(self): + """Clean up after tests""" + if 'PythonDepManager' in sys.modules: + del sys.modules['PythonDepManager'] + + @patch('builtins.print') + def test_dependency_import_failure(self, mock_print): + """Test dependency import failure handling""" + # Mock ensure_import to raise ImportError + self.mock_python_dep_manager.ensure_import = MagicMock(side_effect=ImportError("Package not found")) + + # Test that the error is handled gracefully + with self.assertRaises(SystemExit): + import haven_vlm_connector + + def test_error_messages(self): + """Test that appropriate error messages are displayed""" + # Mock ensure_import to raise ImportError + self.mock_python_dep_manager.ensure_import = MagicMock(side_effect=ImportError("Package not found")) + + with patch('builtins.print') as mock_print: + with self.assertRaises(SystemExit): + import haven_vlm_connector + + # Check that appropriate error messages were printed + print_calls = [call[0][0] for call in mock_print.call_args_list] + self.assertTrue(any("Failed to import PythonDepManager" in msg for msg in print_calls if isinstance(msg, str))) + self.assertTrue(any("Please ensure PythonDepManager is installed" in msg for msg in print_calls if isinstance(msg, str))) + + +class TestDependencyManagementEdgeCases(unittest.TestCase): + """Test edge cases in dependency management""" + + def setUp(self): + """Set up test fixtures""" + self.mock_python_dep_manager = MagicMock() + sys.modules['PythonDepManager'] = 
self.mock_python_dep_manager + + def tearDown(self): + """Clean up after tests""" + if 'PythonDepManager' in sys.modules: + del sys.modules['PythonDepManager'] + + def test_missing_python_dep_manager(self): + """Test behavior when PythonDepManager is not available""" + # Remove PythonDepManager from sys.modules + if 'PythonDepManager' in sys.modules: + del sys.modules['PythonDepManager'] + + with patch('builtins.print') as mock_print: + with self.assertRaises(SystemExit): + import haven_vlm_connector + + # Check that appropriate error message was printed + print_calls = [call[0][0] for call in mock_print.call_args_list] + self.assertTrue(any("Failed to import PythonDepManager" in msg for msg in print_calls if isinstance(msg, str))) + + def test_partial_dependency_failure(self): + """Test behavior when some dependencies fail to import""" + # Mock ensure_import to succeed but some imports to fail + self.mock_python_dep_manager.ensure_import = MagicMock() + + # Mock some successful imports but not all + mock_stashapi = MagicMock() + sys.modules['stashapi.log'] = mock_stashapi + sys.modules['stashapi.stashapp'] = mock_stashapi + + # Don't mock aiohttp, so it should fail + with patch('builtins.print') as mock_print: + with self.assertRaises(SystemExit): + import haven_vlm_connector + + # Check that appropriate error message was printed + print_calls = [call[0][0] for call in mock_print.call_args_list] + self.assertTrue(any("Error during dependency management" in msg for msg in print_calls if isinstance(msg, str))) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/plugins/AHavenVLMConnector/test_haven_media_handler.py b/plugins/AHavenVLMConnector/test_haven_media_handler.py new file mode 100644 index 00000000..ed81b5fc --- /dev/null +++ b/plugins/AHavenVLMConnector/test_haven_media_handler.py @@ -0,0 +1,387 @@ +""" +Unit tests for Haven Media Handler Module +Tests StashApp media operations and tag management +""" + +import unittest +from unittest.mock import Mock, patch, MagicMock +from typing import List, Dict, Any, Optional +import sys +import os + +# Add the current directory to the path to import the module +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +# Mock the dependencies before importing the module +sys.modules['PythonDepManager'] = Mock() +sys.modules['stashapi.stashapp'] = Mock() +sys.modules['stashapi.log'] = Mock() +sys.modules['haven_vlm_config'] = Mock() + +# Import the module after mocking dependencies +import haven_media_handler + + +class TestHavenMediaHandler(unittest.TestCase): + """Test cases for Haven Media Handler""" + + def setUp(self) -> None: + """Set up test fixtures""" + # Mock the stash interface + self.mock_stash = Mock() + self.mock_stash.find_tag.return_value = {"id": 1} + self.mock_stash.get_configuration.return_value = {"ui": {"vrTag": "VR"}} + self.mock_stash.stash_version.return_value = Mock() + + # Mock the log module + self.mock_log = Mock() + + # Patch the global variables + haven_media_handler.stash = self.mock_stash + haven_media_handler.log = self.mock_log + + # Mock tag IDs + haven_media_handler.vlm_errored_tag_id = 1 + haven_media_handler.vlm_tagme_tag_id = 2 + haven_media_handler.vlm_base_tag_id = 3 + haven_media_handler.vlm_tagged_tag_id = 4 + haven_media_handler.vr_tag_id = 5 + haven_media_handler.vlm_incorrect_tag_id = 6 + + def tearDown(self) -> None: + """Clean up after tests""" + # Clear any cached data + haven_media_handler.tag_id_cache.clear() + 
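+        # both module-level caches are reset so tag IDs cached by one test
+        # cannot leak into the next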
haven_media_handler.vlm_tag_ids_cache.clear() + + def test_clear_all_tags_from_video_with_tags(self) -> None: + """Test clearing all tags from a video that has tags""" + # Mock scene with tags + mock_scene = { + "id": 123, + "tags": [ + {"id": 10, "name": "Tag1"}, + {"id": 20, "name": "Tag2"}, + {"id": 30, "name": "Tag3"} + ] + } + # Call the function + haven_media_handler.clear_all_tags_from_video(mock_scene) + # Verify tags were removed + self.mock_stash.update_scenes.assert_called_once_with({ + "ids": [123], + "tag_ids": {"ids": [10, 20, 30], "mode": "REMOVE"} + }) + # Verify log message + self.mock_log.info.assert_called_once_with("Cleared 3 tags from scene 123") + + def test_clear_all_tags_from_video_no_tags(self) -> None: + """Test clearing all tags from a video that has no tags""" + # Mock scene without tags + mock_scene = {"id": 123, "tags": []} + # Call the function + haven_media_handler.clear_all_tags_from_video(mock_scene) + # Verify no update was called since there are no tags + self.mock_stash.update_scenes.assert_not_called() + # Verify no log message + self.mock_log.info.assert_not_called() + + def test_clear_all_tags_from_video_scene_without_tags_key(self) -> None: + """Test clearing all tags from a scene that doesn't have a tags key""" + # Mock scene without tags key + mock_scene = {"id": 123} + # Call the function + haven_media_handler.clear_all_tags_from_video(mock_scene) + # Verify no update was called + self.mock_stash.update_scenes.assert_not_called() + + @patch('haven_media_handler.get_scene_markers') + @patch('haven_media_handler.delete_markers') + def test_clear_all_markers_from_video_with_markers(self, mock_delete_markers: Mock, mock_get_markers: Mock) -> None: + """Test clearing all markers from a video that has markers""" + # Mock markers + mock_markers = [ + {"id": 1, "title": "Marker1"}, + {"id": 2, "title": "Marker2"} + ] + mock_get_markers.return_value = mock_markers + + # Call the function + haven_media_handler.clear_all_markers_from_video(123) + + # Verify markers were retrieved + mock_get_markers.assert_called_once_with(123) + + # Verify markers were deleted + mock_delete_markers.assert_called_once_with(mock_markers) + + # Verify log message + self.mock_log.info.assert_called_once_with("Cleared all 2 markers from scene 123") + + @patch('haven_media_handler.get_scene_markers') + @patch('haven_media_handler.delete_markers') + def test_clear_all_markers_from_video_no_markers(self, mock_delete_markers: Mock, mock_get_markers: Mock) -> None: + """Test clearing all markers from a video that has no markers""" + # Mock no markers + mock_get_markers.return_value = [] + + # Call the function + haven_media_handler.clear_all_markers_from_video(123) + + # Verify markers were retrieved + mock_get_markers.assert_called_once_with(123) + + # Verify no deletion was called + mock_delete_markers.assert_not_called() + + # Verify no log message + self.mock_log.info.assert_not_called() + + def test_add_tags_to_video_with_tagged(self) -> None: + """Test adding tags to video with tagged flag enabled""" + # Call the function + haven_media_handler.add_tags_to_video(123, [10, 20, 30], add_tagged=True) + + # Verify tags were added (including tagged tag) + self.mock_stash.update_scenes.assert_called_once_with({ + "ids": [123], + "tag_ids": {"ids": [10, 20, 30, 4], "mode": "ADD"} + }) + + def test_add_tags_to_video_without_tagged(self) -> None: + """Test adding tags to video with tagged flag disabled""" + # Call the function + haven_media_handler.add_tags_to_video(123, [10, 20, 30], 
add_tagged=False) + + # Verify tags were added (without tagged tag) + self.mock_stash.update_scenes.assert_called_once_with({ + "ids": [123], + "tag_ids": {"ids": [10, 20, 30], "mode": "ADD"} + }) + + @patch('haven_media_handler.get_vlm_tags') + def test_remove_vlm_tags_from_video(self, mock_get_vlm_tags: Mock) -> None: + """Test removing VLM tags from video""" + # Mock VLM tags + mock_get_vlm_tags.return_value = [100, 200, 300] + + # Call the function + haven_media_handler.remove_vlm_tags_from_video(123, remove_tagme=True, remove_errored=True) + + # Verify VLM tags were retrieved + mock_get_vlm_tags.assert_called_once() + + # Verify tags were removed (including tagme and errored tags) + self.mock_stash.update_scenes.assert_called_once_with({ + "ids": [123], + "tag_ids": {"ids": [100, 200, 300, 2, 1], "mode": "REMOVE"} + }) + + def test_get_tagme_scenes(self) -> None: + """Test getting scenes tagged with VLM_TagMe""" + # Mock scenes + mock_scenes = [{"id": 1}, {"id": 2}] + self.mock_stash.find_scenes.return_value = mock_scenes + + # Call the function + result = haven_media_handler.get_tagme_scenes() + + # Verify scenes were found + self.mock_stash.find_scenes.assert_called_once_with( + f={"tags": {"value": 2, "modifier": "INCLUDES"}}, + fragment="id tags {id} files {path duration fingerprint(type: \"phash\")}" + ) + + # Verify result + self.assertEqual(result, mock_scenes) + + def test_add_error_scene(self) -> None: + """Test adding error tag to a scene""" + # Call the function + haven_media_handler.add_error_scene(123) + + # Verify error tag was added + self.mock_stash.update_scenes.assert_called_once_with({ + "ids": [123], + "tag_ids": {"ids": [1], "mode": "ADD"} + }) + + def test_remove_tagme_tag_from_scene(self) -> None: + """Test removing VLM_TagMe tag from a scene""" + # Call the function + haven_media_handler.remove_tagme_tag_from_scene(123) + + # Verify tagme tag was removed + self.mock_stash.update_scenes.assert_called_once_with({ + "ids": [123], + "tag_ids": {"ids": [2], "mode": "REMOVE"} + }) + + def test_is_scene_tagged_true(self) -> None: + """Test checking if a scene is tagged (true case)""" + # Mock tags including tagged tag + tags = [ + {"id": 10, "name": "Tag1"}, + {"id": 4, "name": "VLM_Tagged"}, # This is the tagged tag + {"id": 20, "name": "Tag2"} + ] + + # Call the function + result = haven_media_handler.is_scene_tagged(tags) + + # Verify result + self.assertTrue(result) + + def test_is_scene_tagged_false(self) -> None: + """Test checking if a scene is tagged (false case)""" + # Mock tags without tagged tag + tags = [ + {"id": 10, "name": "Tag1"}, + {"id": 20, "name": "Tag2"} + ] + + # Call the function + result = haven_media_handler.is_scene_tagged(tags) + + # Verify result + self.assertFalse(result) + + def test_is_vr_scene_true(self) -> None: + """Test checking if a scene is VR (true case)""" + # Mock tags including VR tag + tags = [ + {"id": 10, "name": "Tag1"}, + {"id": 5, "name": "VR"}, # This is the VR tag + {"id": 20, "name": "Tag2"} + ] + + # Call the function + result = haven_media_handler.is_vr_scene(tags) + + # Verify result + self.assertTrue(result) + + def test_is_vr_scene_false(self) -> None: + """Test checking if a scene is VR (false case)""" + # Mock tags without VR tag + tags = [ + {"id": 10, "name": "Tag1"}, + {"id": 20, "name": "Tag2"} + ] + + # Call the function + result = haven_media_handler.is_vr_scene(tags) + + # Verify result + self.assertFalse(result) + + def test_get_tag_id_existing(self) -> None: + """Test getting tag ID for existing tag""" 
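+        # (assumes get_tag_id consults the module-level tag_id_cache before
+        # querying stash; tearDown cleared the cache, so find_tag is actually hit)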
+ # Mock existing tag + self.mock_stash.find_tag.return_value = {"id": 123, "name": "TestTag"} + + # Call the function + result = haven_media_handler.get_tag_id("TestTag", create=False) + + # Verify tag was found + self.mock_stash.find_tag.assert_called_once_with("TestTag") + + # Verify result + self.assertEqual(result, 123) + + def test_get_tag_id_not_existing_no_create(self) -> None: + """Test getting tag ID for non-existing tag without create""" + # Mock non-existing tag + self.mock_stash.find_tag.return_value = None + + # Call the function + result = haven_media_handler.get_tag_id("TestTag", create=False) + + # Verify tag was searched + self.mock_stash.find_tag.assert_called_once_with("TestTag") + + # Verify result is None + self.assertIsNone(result) + + def test_get_tag_id_create_new(self) -> None: + """Test getting tag ID for non-existing tag with create""" + # Mock non-existing tag + self.mock_stash.find_tag.return_value = None + + # Mock created tag + self.mock_stash.create_tag.return_value = {"id": 456, "name": "TestTag"} + + # Call the function + result = haven_media_handler.get_tag_id("TestTag", create=True) + + # Verify tag was searched + self.mock_stash.find_tag.assert_called_once_with("TestTag") + + # Verify tag was created + self.mock_stash.create_tag.assert_called_once_with({ + "name": "TestTag", + "ignore_auto_tag": True, + "parent_ids": [3] + }) + + # Verify result + self.assertEqual(result, 456) + + def test_get_tag_ids(self) -> None: + """Test getting multiple tag IDs""" + # Mock tag IDs + with patch('haven_media_handler.get_tag_id') as mock_get_tag_id: + mock_get_tag_id.side_effect = [10, 20, 30] + + # Call the function + result = haven_media_handler.get_tag_ids(["Tag1", "Tag2", "Tag3"], create=True) + + # Verify individual tag IDs were retrieved + self.assertEqual(mock_get_tag_id.call_count, 3) + mock_get_tag_id.assert_any_call("Tag1", True) + mock_get_tag_id.assert_any_call("Tag2", True) + mock_get_tag_id.assert_any_call("Tag3", True) + + # Verify result + self.assertEqual(result, [10, 20, 30]) + + @patch('haven_media_handler.vlm_tag_ids_cache') + def test_get_vlm_tags_from_cache(self, mock_cache: Mock) -> None: + """Test getting VLM tags from cache""" + # Mock cached tags + mock_cache.__len__.return_value = 3 + mock_cache.__iter__.return_value = iter([100, 200, 300]) + + # Call the function + result = haven_media_handler.get_vlm_tags() + + # Verify result from cache + self.assertEqual(result, [100, 200, 300]) + + def test_get_vlm_tags_from_stash(self) -> None: + """Test getting VLM tags from stash when cache is empty""" + # Mock empty cache + haven_media_handler.vlm_tag_ids_cache.clear() + + # Mock stash tags + mock_tags = [ + {"id": 100, "name": "VLM_Tag1"}, + {"id": 200, "name": "VLM_Tag2"} + ] + self.mock_stash.find_tags.return_value = mock_tags + + # Call the function + result = haven_media_handler.get_vlm_tags() + + # Verify tags were found + self.mock_stash.find_tags.assert_called_once_with( + f={"parents": {"value": 3, "modifier": "INCLUDES"}}, + fragment="id" + ) + + # Verify result + self.assertEqual(result, [100, 200]) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/plugins/AHavenVLMConnector/test_haven_vlm_config.py b/plugins/AHavenVLMConnector/test_haven_vlm_config.py new file mode 100644 index 00000000..464e295f --- /dev/null +++ b/plugins/AHavenVLMConnector/test_haven_vlm_config.py @@ -0,0 +1,286 @@ +""" +Unit tests for haven_vlm_config module +""" + +import unittest +import tempfile +import os +import yaml 
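+# yaml is used below to author temporary config fixtures for
+# load_config_from_yaml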
+from unittest.mock import patch, mock_open +from dataclasses import dataclass + +import haven_vlm_config + + +class TestVLMConnectorConfig(unittest.TestCase): + """Test cases for VLMConnectorConfig dataclass""" + + def test_vlm_connector_config_creation(self): + """Test creating VLMConnectorConfig with all required fields""" + config = haven_vlm_config.VLMConnectorConfig( + vlm_engine_config={"test": "config"}, + video_frame_interval=2.0, + video_threshold=0.3, + video_confidence_return=True, + image_threshold=0.5, + image_batch_size=320, + image_confidence_return=False, + concurrent_task_limit=10, + server_timeout=3700, + vlm_base_tag_name="VLM", + vlm_tagme_tag_name="VLM_TagMe", + vlm_updateme_tag_name="VLM_UpdateMe", + vlm_tagged_tag_name="VLM_Tagged", + vlm_errored_tag_name="VLM_Errored", + vlm_incorrect_tag_name="VLM_Incorrect", + temp_image_dir="./temp_images", + output_data_dir="./output_data", + delete_incorrect_markers=True, + create_markers=True, + path_mutation={} + ) + + self.assertEqual(config.video_frame_interval, 2.0) + self.assertEqual(config.video_threshold, 0.3) + self.assertEqual(config.image_threshold, 0.5) + self.assertEqual(config.concurrent_task_limit, 10) + self.assertEqual(config.vlm_base_tag_name, "VLM") + self.assertEqual(config.temp_image_dir, "./temp_images") + + def test_vlm_connector_config_defaults(self): + """Test VLMConnectorConfig with minimal required fields""" + config = haven_vlm_config.VLMConnectorConfig( + vlm_engine_config={}, + video_frame_interval=1.0, + video_threshold=0.1, + video_confidence_return=False, + image_threshold=0.1, + image_batch_size=100, + image_confidence_return=False, + concurrent_task_limit=5, + server_timeout=1000, + vlm_base_tag_name="TEST", + vlm_tagme_tag_name="TEST_TagMe", + vlm_updateme_tag_name="TEST_UpdateMe", + vlm_tagged_tag_name="TEST_Tagged", + vlm_errored_tag_name="TEST_Errored", + vlm_incorrect_tag_name="TEST_Incorrect", + temp_image_dir="./test_temp", + output_data_dir="./test_output", + delete_incorrect_markers=False, + create_markers=False, + path_mutation={"test": "mutation"} + ) + + self.assertEqual(config.video_frame_interval, 1.0) + self.assertEqual(config.video_threshold, 0.1) + self.assertEqual(config.path_mutation, {"test": "mutation"}) + + +class TestLoadConfigFromYaml(unittest.TestCase): + """Test cases for load_config_from_yaml function""" + + def setUp(self): + """Set up test fixtures""" + self.test_config = { + "vlm_engine_config": { + "active_ai_models": ["test_model"], + "pipelines": {}, + "models": {}, + "category_config": {} + }, + "video_frame_interval": 3.0, + "video_threshold": 0.4, + "video_confidence_return": True, + "image_threshold": 0.6, + "image_batch_size": 500, + "image_confidence_return": True, + "concurrent_task_limit": 15, + "server_timeout": 5000, + "vlm_base_tag_name": "TEST_VLM", + "vlm_tagme_tag_name": "TEST_VLM_TagMe", + "vlm_updateme_tag_name": "TEST_VLM_UpdateMe", + "vlm_tagged_tag_name": "TEST_VLM_Tagged", + "vlm_errored_tag_name": "TEST_VLM_Errored", + "vlm_incorrect_tag_name": "TEST_VLM_Incorrect", + "temp_image_dir": "./test_temp_images", + "output_data_dir": "./test_output_data", + "delete_incorrect_markers": False, + "create_markers": False, + "path_mutation": {"E:": "F:"} + } + + def test_load_config_from_yaml_with_valid_file(self): + """Test loading configuration from a valid YAML file""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yml', delete=False) as f: + yaml.dump(self.test_config, f) + config_path = f.name + + try: + config = 
haven_vlm_config.load_config_from_yaml(config_path) + + self.assertIsInstance(config, haven_vlm_config.VLMConnectorConfig) + self.assertEqual(config.video_frame_interval, 3.0) + self.assertEqual(config.video_threshold, 0.4) + self.assertEqual(config.image_threshold, 0.6) + self.assertEqual(config.concurrent_task_limit, 15) + self.assertEqual(config.vlm_base_tag_name, "TEST_VLM") + self.assertEqual(config.path_mutation, {"E:": "F:"}) + finally: + os.unlink(config_path) + + def test_load_config_from_yaml_with_nonexistent_file(self): + """Test loading configuration with nonexistent file path""" + config = haven_vlm_config.load_config_from_yaml("nonexistent_file.yml") + + # Should return default configuration + self.assertIsInstance(config, haven_vlm_config.VLMConnectorConfig) + self.assertEqual(config.video_frame_interval, haven_vlm_config.VIDEO_FRAME_INTERVAL) + self.assertEqual(config.video_threshold, haven_vlm_config.VIDEO_THRESHOLD) + + def test_load_config_from_yaml_with_none_path(self): + """Test loading configuration with None path""" + config = haven_vlm_config.load_config_from_yaml(None) + + # Should return default configuration + self.assertIsInstance(config, haven_vlm_config.VLMConnectorConfig) + self.assertEqual(config.video_frame_interval, haven_vlm_config.VIDEO_FRAME_INTERVAL) + + def test_load_config_from_yaml_with_invalid_yaml(self): + """Test loading configuration with invalid YAML content""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yml', delete=False) as f: + f.write("invalid: yaml: content: [") + config_path = f.name + + try: + config = haven_vlm_config.load_config_from_yaml(config_path) + + # Should return default configuration on YAML error + self.assertIsInstance(config, haven_vlm_config.VLMConnectorConfig) + self.assertEqual(config.video_frame_interval, haven_vlm_config.VIDEO_FRAME_INTERVAL) + finally: + os.unlink(config_path) + + def test_load_config_from_yaml_with_file_permission_error(self): + """Test loading configuration with file permission error""" + with patch('builtins.open', side_effect=PermissionError("Permission denied")): + config = haven_vlm_config.load_config_from_yaml("test.yml") + + # Should return default configuration on file error + self.assertIsInstance(config, haven_vlm_config.VLMConnectorConfig) + self.assertEqual(config.video_frame_interval, haven_vlm_config.VIDEO_FRAME_INTERVAL) + + +class TestConfigurationConstants(unittest.TestCase): + """Test cases for configuration constants""" + + def test_vlm_engine_config_structure(self): + """Test that VLM_ENGINE_CONFIG has the expected structure""" + config = haven_vlm_config.VLM_ENGINE_CONFIG + + # Check required top-level keys + self.assertIn("active_ai_models", config) + self.assertIn("pipelines", config) + self.assertIn("models", config) + self.assertIn("category_config", config) + + # Check active_ai_models is a list + self.assertIsInstance(config["active_ai_models"], list) + self.assertIn("vlm_multiplexer_model", config["active_ai_models"]) + + # Check pipelines structure + self.assertIn("video_pipeline_dynamic", config["pipelines"]) + pipeline = config["pipelines"]["video_pipeline_dynamic"] + self.assertIn("inputs", pipeline) + self.assertIn("output", pipeline) + self.assertIn("models", pipeline) + + # Check models structure + self.assertIn("vlm_multiplexer_model", config["models"]) + model = config["models"]["vlm_multiplexer_model"] + self.assertIn("type", model) + self.assertIn("multiplexer_endpoints", model) + self.assertIn("tag_list", model) + + def 
test_processing_settings(self): + """Test that processing settings have valid values""" + self.assertGreater(haven_vlm_config.VIDEO_FRAME_INTERVAL, 0) + self.assertGreaterEqual(haven_vlm_config.VIDEO_THRESHOLD, 0) + self.assertLessEqual(haven_vlm_config.VIDEO_THRESHOLD, 1) + self.assertGreaterEqual(haven_vlm_config.IMAGE_THRESHOLD, 0) + self.assertLessEqual(haven_vlm_config.IMAGE_THRESHOLD, 1) + self.assertGreater(haven_vlm_config.IMAGE_BATCH_SIZE, 0) + self.assertGreater(haven_vlm_config.CONCURRENT_TASK_LIMIT, 0) + self.assertGreater(haven_vlm_config.SERVER_TIMEOUT, 0) + + def test_tag_names(self): + """Test that tag names are valid strings""" + tag_names = [ + haven_vlm_config.VLM_BASE_TAG_NAME, + haven_vlm_config.VLM_TAGME_TAG_NAME, + haven_vlm_config.VLM_UPDATEME_TAG_NAME, + haven_vlm_config.VLM_TAGGED_TAG_NAME, + haven_vlm_config.VLM_ERRORED_TAG_NAME, + haven_vlm_config.VLM_INCORRECT_TAG_NAME + ] + + for tag_name in tag_names: + self.assertIsInstance(tag_name, str) + self.assertGreater(len(tag_name), 0) + + def test_directory_paths(self): + """Test that directory paths are valid strings""" + self.assertIsInstance(haven_vlm_config.TEMP_IMAGE_DIR, str) + self.assertIsInstance(haven_vlm_config.OUTPUT_DATA_DIR, str) + self.assertGreater(len(haven_vlm_config.TEMP_IMAGE_DIR), 0) + self.assertGreater(len(haven_vlm_config.OUTPUT_DATA_DIR), 0) + + def test_boolean_settings(self): + """Test that boolean settings are valid""" + self.assertIsInstance(haven_vlm_config.DELETE_INCORRECT_MARKERS, bool) + self.assertIsInstance(haven_vlm_config.CREATE_MARKERS, bool) + + def test_path_mutation(self): + """Test that path mutation is a dictionary""" + self.assertIsInstance(haven_vlm_config.PATH_MUTATION, dict) + + +class TestGlobalConfigInstance(unittest.TestCase): + """Test cases for the global config instance""" + + def test_global_config_exists(self): + """Test that the global config instance exists and is valid""" + self.assertIsInstance(haven_vlm_config.config, haven_vlm_config.VLMConnectorConfig) + + def test_global_config_has_required_attributes(self): + """Test that the global config has all required attributes""" + config = haven_vlm_config.config + + # Check that all required attributes exist + required_attrs = [ + 'vlm_engine_config', 'video_frame_interval', 'video_threshold', + 'video_confidence_return', 'image_threshold', 'image_batch_size', + 'image_confidence_return', 'concurrent_task_limit', 'server_timeout', + 'vlm_base_tag_name', 'vlm_tagme_tag_name', 'vlm_updateme_tag_name', + 'vlm_tagged_tag_name', 'vlm_errored_tag_name', 'vlm_incorrect_tag_name', + 'temp_image_dir', 'output_data_dir', 'delete_incorrect_markers', + 'create_markers', 'path_mutation' + ] + + for attr in required_attrs: + self.assertTrue(hasattr(config, attr), f"Missing attribute: {attr}") + + def test_global_config_values(self): + """Test that the global config has expected default values""" + config = haven_vlm_config.config + + self.assertEqual(config.video_frame_interval, haven_vlm_config.VIDEO_FRAME_INTERVAL) + self.assertEqual(config.video_threshold, haven_vlm_config.VIDEO_THRESHOLD) + self.assertEqual(config.image_threshold, haven_vlm_config.IMAGE_THRESHOLD) + self.assertEqual(config.concurrent_task_limit, haven_vlm_config.CONCURRENT_TASK_LIMIT) + self.assertEqual(config.vlm_base_tag_name, haven_vlm_config.VLM_BASE_TAG_NAME) + self.assertEqual(config.temp_image_dir, haven_vlm_config.TEMP_IMAGE_DIR) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git 
a/plugins/AHavenVLMConnector/test_haven_vlm_connector.py b/plugins/AHavenVLMConnector/test_haven_vlm_connector.py new file mode 100644 index 00000000..c77f9122 --- /dev/null +++ b/plugins/AHavenVLMConnector/test_haven_vlm_connector.py @@ -0,0 +1,451 @@ +""" +Unit tests for haven_vlm_connector module +""" + +import unittest +import asyncio +import json +import tempfile +import os +from unittest.mock import patch, MagicMock, AsyncMock, mock_open +import sys + +# Mock the stashapi imports +sys.modules['stashapi.log'] = MagicMock() +sys.modules['stashapi.stashapp'] = MagicMock() + +# Mock the vlm_engine imports +sys.modules['vlm_engine'] = MagicMock() +sys.modules['vlm_engine.config_models'] = MagicMock() + +import haven_vlm_connector + + +class TestMainExecution(unittest.TestCase): + """Test cases for main execution functions""" + + def setUp(self): + """Set up test fixtures""" + self.sample_json_input = { + "server_connection": { + "PluginDir": "/tmp/plugin" + }, + "args": { + "mode": "tag_videos" + } + } + + @patch('haven_vlm_connector.media_handler') + @patch('haven_vlm_connector.tag_videos') + @patch('haven_vlm_connector.os.chdir') + def test_run_tag_videos(self, mock_chdir, mock_tag_videos, mock_media_handler): + """Test running tag_videos mode""" + output = {} + + with patch('haven_vlm_connector.read_json_input', return_value=self.sample_json_input): + asyncio.run(haven_vlm_connector.run(self.sample_json_input, output)) + + mock_chdir.assert_called_once_with("/tmp/plugin") + mock_media_handler.initialize.assert_called_once_with(self.sample_json_input["server_connection"]) + mock_tag_videos.assert_called_once() + self.assertEqual(output["output"], "ok") + + @patch('haven_vlm_connector.media_handler') + @patch('haven_vlm_connector.tag_images') + @patch('haven_vlm_connector.os.chdir') + def test_run_tag_images(self, mock_chdir, mock_tag_images, mock_media_handler): + """Test running tag_images mode""" + json_input = self.sample_json_input.copy() + json_input["args"]["mode"] = "tag_images" + output = {} + + asyncio.run(haven_vlm_connector.run(json_input, output)) + + mock_tag_images.assert_called_once() + self.assertEqual(output["output"], "ok") + + @patch('haven_vlm_connector.media_handler') + @patch('haven_vlm_connector.find_marker_settings') + @patch('haven_vlm_connector.os.chdir') + def test_run_find_marker_settings(self, mock_chdir, mock_find_marker_settings, mock_media_handler): + """Test running find_marker_settings mode""" + json_input = self.sample_json_input.copy() + json_input["args"]["mode"] = "find_marker_settings" + output = {} + + asyncio.run(haven_vlm_connector.run(json_input, output)) + + mock_find_marker_settings.assert_called_once() + self.assertEqual(output["output"], "ok") + + @patch('haven_vlm_connector.media_handler') + @patch('haven_vlm_connector.collect_incorrect_markers_and_images') + @patch('haven_vlm_connector.os.chdir') + def test_run_collect_incorrect_markers(self, mock_chdir, mock_collect, mock_media_handler): + """Test running collect_incorrect_markers mode""" + json_input = self.sample_json_input.copy() + json_input["args"]["mode"] = "collect_incorrect_markers" + output = {} + + asyncio.run(haven_vlm_connector.run(json_input, output)) + + mock_collect.assert_called_once() + self.assertEqual(output["output"], "ok") + + @patch('haven_vlm_connector.media_handler') + @patch('haven_vlm_connector.os.chdir') + def test_run_no_mode(self, mock_chdir, mock_media_handler): + """Test running with no mode specified""" + json_input = self.sample_json_input.copy() + del 
json_input["args"]["mode"] + output = {} + + asyncio.run(haven_vlm_connector.run(json_input, output)) + + self.assertEqual(output["output"], "ok") + + @patch('haven_vlm_connector.media_handler') + def test_run_media_handler_initialization_error(self, mock_media_handler): + """Test handling media handler initialization error""" + mock_media_handler.initialize.side_effect = Exception("Initialization failed") + output = {} + + with self.assertRaises(Exception): + asyncio.run(haven_vlm_connector.run(self.sample_json_input, output)) + + def test_read_json_input(self): + """Test reading JSON input from stdin""" + test_input = '{"test": "data"}' + + with patch('sys.stdin.read', return_value=test_input): + result = haven_vlm_connector.read_json_input() + + self.assertEqual(result, {"test": "data"}) + + +class TestHighLevelProcessingFunctions(unittest.TestCase): + """Test cases for high-level processing functions""" + + @patch('haven_vlm_connector.media_handler') + @patch('haven_vlm_connector.__tag_images') + @patch('haven_vlm_connector.asyncio.gather') + def test_tag_images_with_images(self, mock_gather, mock_tag_images, mock_media_handler): + """Test tagging images when images are available""" + mock_images = [{"id": 1}, {"id": 2}, {"id": 3}] + mock_media_handler.get_tagme_images.return_value = mock_images + + asyncio.run(haven_vlm_connector.tag_images()) + + mock_media_handler.get_tagme_images.assert_called_once() + mock_gather.assert_called_once() + + @patch('haven_vlm_connector.media_handler') + def test_tag_images_no_images(self, mock_media_handler): + """Test tagging images when no images are available""" + mock_media_handler.get_tagme_images.return_value = [] + + asyncio.run(haven_vlm_connector.tag_images()) + + mock_media_handler.get_tagme_images.assert_called_once() + + @patch('haven_vlm_connector.media_handler') + @patch('haven_vlm_connector.__tag_video') + @patch('haven_vlm_connector.asyncio.gather') + def test_tag_videos_with_scenes(self, mock_gather, mock_tag_video, mock_media_handler): + """Test tagging videos when scenes are available""" + mock_scenes = [{"id": 1}, {"id": 2}] + mock_media_handler.get_tagme_scenes.return_value = mock_scenes + + asyncio.run(haven_vlm_connector.tag_videos()) + + mock_media_handler.get_tagme_scenes.assert_called_once() + mock_gather.assert_called_once() + + @patch('haven_vlm_connector.media_handler') + def test_tag_videos_no_scenes(self, mock_media_handler): + """Test tagging videos when no scenes are available""" + mock_media_handler.get_tagme_scenes.return_value = [] + + asyncio.run(haven_vlm_connector.tag_videos()) + + mock_media_handler.get_tagme_scenes.assert_called_once() + + @patch('haven_vlm_connector.media_handler') + @patch('haven_vlm_connector.__find_marker_settings') + def test_find_marker_settings_single_scene(self, mock_find_settings, mock_media_handler): + """Test finding marker settings with single scene""" + mock_scenes = [{"id": 1}] + mock_media_handler.get_tagme_scenes.return_value = mock_scenes + + asyncio.run(haven_vlm_connector.find_marker_settings()) + + mock_media_handler.get_tagme_scenes.assert_called_once() + mock_find_settings.assert_called_once_with(mock_scenes[0]) + + @patch('haven_vlm_connector.media_handler') + def test_find_marker_settings_no_scenes(self, mock_media_handler): + """Test finding marker settings with no scenes""" + mock_media_handler.get_tagme_scenes.return_value = [] + + asyncio.run(haven_vlm_connector.find_marker_settings()) + + mock_media_handler.get_tagme_scenes.assert_called_once() + + 
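+    # Taken together, the find_marker_settings cases (single scene, no scenes,
+    # and multiple scenes below) encode the assumed contract: optimization
+    # runs only when exactly one scene carries VLM_TagMe.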
@patch('haven_vlm_connector.media_handler') + def test_find_marker_settings_multiple_scenes(self, mock_media_handler): + """Test finding marker settings with multiple scenes""" + mock_scenes = [{"id": 1}, {"id": 2}] + mock_media_handler.get_tagme_scenes.return_value = mock_scenes + + asyncio.run(haven_vlm_connector.find_marker_settings()) + + mock_media_handler.get_tagme_scenes.assert_called_once() + + @patch('haven_vlm_connector.media_handler') + @patch('haven_vlm_connector.os.makedirs') + @patch('haven_vlm_connector.shutil.copy') + def test_collect_incorrect_markers_and_images_with_data(self, mock_copy, mock_makedirs, mock_media_handler): + """Test collecting incorrect markers and images with data""" + mock_images = [{"id": 1, "files": [{"path": "/path/to/image.jpg"}]}] + mock_markers = [{"id": 1, "scene": {"files": [{"path": "/path/to/video.mp4"}]}, "primary_tag": {"name": "test"}}] + mock_media_handler.get_incorrect_images.return_value = mock_images + mock_media_handler.get_incorrect_markers.return_value = mock_markers + mock_media_handler.get_image_paths_and_ids.return_value = (["/path/to/image.jpg"], [1], []) + + haven_vlm_connector.collect_incorrect_markers_and_images() + + mock_media_handler.get_incorrect_images.assert_called_once() + mock_media_handler.get_incorrect_markers.assert_called_once() + mock_media_handler.remove_incorrect_tag_from_images.assert_called_once() + + @patch('haven_vlm_connector.media_handler') + def test_collect_incorrect_markers_and_images_no_data(self, mock_media_handler): + """Test collecting incorrect markers and images with no data""" + mock_media_handler.get_incorrect_images.return_value = [] + mock_media_handler.get_incorrect_markers.return_value = [] + + haven_vlm_connector.collect_incorrect_markers_and_images() + + mock_media_handler.get_incorrect_images.assert_called_once() + mock_media_handler.get_incorrect_markers.assert_called_once() + + +class TestLowLevelProcessingFunctions(unittest.TestCase): + """Test cases for low-level processing functions""" + + @patch('haven_vlm_connector.vlm_engine') + @patch('haven_vlm_connector.media_handler') + @patch('haven_vlm_connector.semaphore') + def test_tag_images_success(self, mock_semaphore, mock_media_handler, mock_vlm_engine): + """Test successful image tagging""" + mock_images = [{"id": 1}, {"id": 2}] + mock_media_handler.get_image_paths_and_ids.return_value = (["/path1.jpg", "/path2.jpg"], [1, 2], []) + mock_vlm_engine.process_images_async.return_value = MagicMock(result=[{"tags": ["tag1"]}, {"tags": ["tag2"]}]) + mock_media_handler.get_tag_ids.return_value = [100, 200] + + # Mock semaphore context manager + mock_semaphore.__aenter__ = AsyncMock() + mock_semaphore.__aexit__ = AsyncMock() + + asyncio.run(haven_vlm_connector.__tag_images(mock_images)) + + mock_media_handler.get_image_paths_and_ids.assert_called_once_with(mock_images) + mock_vlm_engine.process_images_async.assert_called_once() + mock_media_handler.remove_tagme_tags_from_images.assert_called_once() + + @patch('haven_vlm_connector.vlm_engine') + @patch('haven_vlm_connector.media_handler') + @patch('haven_vlm_connector.semaphore') + def test_tag_images_error(self, mock_semaphore, mock_media_handler, mock_vlm_engine): + """Test image tagging with error""" + mock_images = [{"id": 1}] + mock_vlm_engine.process_images_async.side_effect = Exception("Processing error") + + # Mock semaphore context manager + mock_semaphore.__aenter__ = AsyncMock() + mock_semaphore.__aexit__ = AsyncMock() + + asyncio.run(haven_vlm_connector.__tag_images(mock_images)) 
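+        # the connector is expected to swallow the processing error and tag
+        # the images as errored rather than re-raising.
+        # NOTE: inside a class body the attribute name __tag_images is
+        # name-mangled (here to _TestLowLevelProcessingFunctions__tag_images),
+        # so this direct call would raise AttributeError at runtime; a
+        # getattr(haven_vlm_connector, '__tag_images') lookup avoids the
+        # mangling (the __tag_video and __find_marker_settings calls below
+        # share the same issue).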
+ + mock_media_handler.add_error_images.assert_called_once() + + @patch('haven_vlm_connector.vlm_engine') + @patch('haven_vlm_connector.media_handler') + @patch('haven_vlm_connector.semaphore') + def test_tag_video_success(self, mock_semaphore, mock_media_handler, mock_vlm_engine): + """Test successful video tagging""" + mock_scene = { + "id": 1, + "files": [{"path": "/path/to/video.mp4"}], + "tags": [] + } + mock_vlm_engine.process_video_async.return_value = MagicMock( + video_tags={"category": ["tag1", "tag2"]}, + tag_timespans={} + ) + mock_media_handler.is_vr_scene.return_value = False + mock_media_handler.get_tag_ids.return_value = [100, 200] + + # Mock semaphore context manager + mock_semaphore.__aenter__ = AsyncMock() + mock_semaphore.__aexit__ = AsyncMock() + + asyncio.run(haven_vlm_connector.__tag_video(mock_scene)) + + mock_vlm_engine.process_video_async.assert_called_once() + + # Verify tags and markers were cleared before adding new ones + mock_media_handler.clear_all_tags_from_video.assert_called_once_with(1) + mock_media_handler.clear_all_markers_from_video.assert_called_once_with(1) + + mock_media_handler.add_tags_to_video.assert_called_once() + mock_media_handler.remove_tagme_tag_from_scene.assert_called_once() + + @patch('haven_vlm_connector.vlm_engine') + @patch('haven_vlm_connector.media_handler') + @patch('haven_vlm_connector.semaphore') + def test_tag_video_error(self, mock_semaphore, mock_media_handler, mock_vlm_engine): + """Test video tagging with error""" + mock_scene = { + "id": 1, + "files": [{"path": "/path/to/video.mp4"}], + "tags": [] + } + mock_vlm_engine.process_video_async.side_effect = Exception("Processing error") + + # Mock semaphore context manager + mock_semaphore.__aenter__ = AsyncMock() + mock_semaphore.__aexit__ = AsyncMock() + + asyncio.run(haven_vlm_connector.__tag_video(mock_scene)) + + mock_media_handler.add_error_scene.assert_called_once() + + @patch('haven_vlm_connector.vlm_engine') + @patch('haven_vlm_connector.media_handler') + def test_find_marker_settings_success(self, mock_media_handler, mock_vlm_engine): + """Test successful marker settings finding""" + mock_scene = { + "id": 1, + "files": [{"path": "/path/to/video.mp4"}] + } + mock_markers = [ + { + "primary_tag": {"name": "tag1"}, + "seconds": 10.0, + "end_seconds": 15.0 + } + ] + mock_media_handler.get_scene_markers.return_value = mock_markers + mock_vlm_engine.find_optimal_marker_settings_async.return_value = {"optimal": "settings"} + + asyncio.run(haven_vlm_connector.__find_marker_settings(mock_scene)) + + mock_media_handler.get_scene_markers.assert_called_once_with(1) + mock_vlm_engine.find_optimal_marker_settings_async.assert_called_once() + + @patch('haven_vlm_connector.media_handler') + def test_find_marker_settings_error(self, mock_media_handler): + """Test marker settings finding with error""" + mock_scene = { + "id": 1, + "files": [{"path": "/path/to/video.mp4"}] + } + mock_media_handler.get_scene_markers.side_effect = Exception("Marker error") + + asyncio.run(haven_vlm_connector.__find_marker_settings(mock_scene)) + + mock_media_handler.get_scene_markers.assert_called_once() + + +class TestUtilityFunctions(unittest.TestCase): + """Test cases for utility functions""" + + def test_increment_progress(self): + """Test progress increment""" + haven_vlm_connector.progress = 0.0 + haven_vlm_connector.increment = 0.1 + + haven_vlm_connector.increment_progress() + + self.assertEqual(haven_vlm_connector.progress, 0.1) + + @patch('haven_vlm_connector.vlm_engine') + async def 
test_cleanup(self, mock_vlm_engine): + """Test cleanup function""" + mock_vlm_engine.vlm_engine = MagicMock() + + await haven_vlm_connector.cleanup() + + mock_vlm_engine.vlm_engine.shutdown.assert_called_once() + + +class TestMainFunction(unittest.TestCase): + """Test cases for main function""" + + @patch('haven_vlm_connector.run') + @patch('haven_vlm_connector.read_json_input') + @patch('haven_vlm_connector.json.dumps') + @patch('builtins.print') + def test_main_success(self, mock_print, mock_json_dumps, mock_read_input, mock_run): + """Test successful main execution""" + mock_read_input.return_value = {"test": "data"} + mock_json_dumps.return_value = '{"output": "ok"}' + + asyncio.run(haven_vlm_connector.main()) + + mock_read_input.assert_called_once() + mock_run.assert_called_once() + mock_json_dumps.assert_called_once() + mock_print.assert_called() + + +class TestErrorHandling(unittest.TestCase): + """Test cases for error handling""" + + @patch('haven_vlm_connector.media_handler') + def test_tag_images_empty_paths(self, mock_media_handler): + """Test image tagging with empty paths""" + mock_images = [{"id": 1}] + mock_media_handler.get_image_paths_and_ids.return_value = ([], [1], []) + + # Mock semaphore context manager + with patch('haven_vlm_connector.semaphore') as mock_semaphore: + mock_semaphore.__aenter__ = AsyncMock() + mock_semaphore.__aexit__ = AsyncMock() + + asyncio.run(haven_vlm_connector.__tag_images(mock_images)) + + mock_media_handler.get_image_paths_and_ids.assert_called_once() + + @patch('haven_vlm_connector.vlm_engine') + @patch('haven_vlm_connector.media_handler') + def test_tag_video_no_detected_tags(self, mock_media_handler, mock_vlm_engine): + """Test video tagging with no detected tags""" + mock_scene = { + "id": 1, + "files": [{"path": "/path/to/video.mp4"}], + "tags": [] + } + mock_vlm_engine.process_video_async.return_value = MagicMock( + video_tags={}, + tag_timespans={} + ) + mock_media_handler.is_vr_scene.return_value = False + + # Mock semaphore context manager + with patch('haven_vlm_connector.semaphore') as mock_semaphore: + mock_semaphore.__aenter__ = AsyncMock() + mock_semaphore.__aexit__ = AsyncMock() + + asyncio.run(haven_vlm_connector.__tag_video(mock_scene)) + + # Verify clearing functions are NOT called when no tags are detected + mock_media_handler.clear_all_tags_from_video.assert_not_called() + mock_media_handler.clear_all_markers_from_video.assert_not_called() + + mock_media_handler.remove_tagme_tag_from_scene.assert_called_once() + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/plugins/AHavenVLMConnector/test_haven_vlm_engine.py b/plugins/AHavenVLMConnector/test_haven_vlm_engine.py new file mode 100644 index 00000000..65adcc55 --- /dev/null +++ b/plugins/AHavenVLMConnector/test_haven_vlm_engine.py @@ -0,0 +1,544 @@ +""" +Unit tests for haven_vlm_engine module +""" + +import unittest +import asyncio +import json +import tempfile +import os +from unittest.mock import patch, MagicMock, AsyncMock, mock_open +import sys + +# Mock the vlm_engine imports +sys.modules['vlm_engine'] = MagicMock() +sys.modules['vlm_engine.config_models'] = MagicMock() + +import haven_vlm_engine + + +class TestTimeFrame(unittest.TestCase): + """Test cases for TimeFrame dataclass""" + + def test_timeframe_creation(self): + """Test creating TimeFrame with all parameters""" + timeframe = haven_vlm_engine.TimeFrame( + start=10.0, + end=15.0, + total_confidence=0.85 + ) + + self.assertEqual(timeframe.start, 10.0) + 
self.assertEqual(timeframe.end, 15.0) + self.assertEqual(timeframe.total_confidence, 0.85) + + def test_timeframe_creation_without_confidence(self): + """Test creating TimeFrame without confidence""" + timeframe = haven_vlm_engine.TimeFrame( + start=10.0, + end=15.0 + ) + + self.assertEqual(timeframe.start, 10.0) + self.assertEqual(timeframe.end, 15.0) + self.assertIsNone(timeframe.total_confidence) + + def test_timeframe_to_json(self): + """Test TimeFrame to_json method""" + timeframe = haven_vlm_engine.TimeFrame( + start=10.0, + end=15.0, + total_confidence=0.85 + ) + + json_str = timeframe.to_json() + json_data = json.loads(json_str) + + self.assertEqual(json_data["start"], 10.0) + self.assertEqual(json_data["end"], 15.0) + self.assertEqual(json_data["total_confidence"], 0.85) + + def test_timeframe_to_json_without_confidence(self): + """Test TimeFrame to_json method without confidence""" + timeframe = haven_vlm_engine.TimeFrame( + start=10.0, + end=15.0 + ) + + json_str = timeframe.to_json() + json_data = json.loads(json_str) + + self.assertEqual(json_data["start"], 10.0) + self.assertEqual(json_data["end"], 15.0) + self.assertIsNone(json_data["total_confidence"]) + + def test_timeframe_str(self): + """Test TimeFrame string representation""" + timeframe = haven_vlm_engine.TimeFrame( + start=10.0, + end=15.0, + total_confidence=0.85 + ) + + str_repr = str(timeframe) + self.assertIn("10.0", str_repr) + self.assertIn("15.0", str_repr) + self.assertIn("0.85", str_repr) + + +class TestVideoTagInfo(unittest.TestCase): + """Test cases for VideoTagInfo dataclass""" + + def test_videotaginfo_creation(self): + """Test creating VideoTagInfo with all parameters""" + video_tags = {"category1": {"tag1", "tag2"}} + tag_totals = {"tag1": {"total": 0.8}} + tag_timespans = {"category1": {"tag1": [haven_vlm_engine.TimeFrame(10.0, 15.0)]}} + + video_info = haven_vlm_engine.VideoTagInfo( + video_duration=120.0, + video_tags=video_tags, + tag_totals=tag_totals, + tag_timespans=tag_timespans + ) + + self.assertEqual(video_info.video_duration, 120.0) + self.assertEqual(video_info.video_tags, video_tags) + self.assertEqual(video_info.tag_totals, tag_totals) + self.assertEqual(video_info.tag_timespans, tag_timespans) + + def test_videotaginfo_from_json(self): + """Test creating VideoTagInfo from JSON data""" + json_data = { + "video_duration": 120.0, + "video_tags": {"category1": ["tag1", "tag2"]}, + "tag_totals": {"tag1": {"total": 0.8}}, + "tag_timespans": { + "category1": { + "tag1": [ + {"start": 10.0, "end": 15.0, "total_confidence": 0.85} + ] + } + } + } + + video_info = haven_vlm_engine.VideoTagInfo.from_json(json_data) + + self.assertEqual(video_info.video_duration, 120.0) + self.assertEqual(video_info.video_tags, {"category1": ["tag1", "tag2"]}) + self.assertEqual(video_info.tag_totals, {"tag1": {"total": 0.8}}) + + # Check that tag_timespans contains TimeFrame objects + self.assertIn("category1", video_info.tag_timespans) + self.assertIn("tag1", video_info.tag_timespans["category1"]) + self.assertIsInstance(video_info.tag_timespans["category1"]["tag1"][0], haven_vlm_engine.TimeFrame) + + def test_videotaginfo_from_json_without_confidence(self): + """Test creating VideoTagInfo from JSON data without confidence""" + json_data = { + "video_duration": 120.0, + "video_tags": {"category1": ["tag1"]}, + "tag_totals": {"tag1": {"total": 0.8}}, + "tag_timespans": { + "category1": { + "tag1": [ + {"start": 10.0, "end": 15.0} + ] + } + } + } + + video_info = haven_vlm_engine.VideoTagInfo.from_json(json_data) + + 
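+        # from_json is expected to default total_confidence to None when the
+        # source JSON omits the key, which the assertions below verify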
timeframe = video_info.tag_timespans["category1"]["tag1"][0]
+        self.assertEqual(timeframe.start, 10.0)
+        self.assertEqual(timeframe.end, 15.0)
+        self.assertIsNone(timeframe.total_confidence)
+
+    def test_videotaginfo_from_json_empty_timespans(self):
+        """Test creating VideoTagInfo from JSON data with empty timespans"""
+        json_data = {
+            "video_duration": 120.0,
+            "video_tags": {"category1": ["tag1"]},
+            "tag_totals": {"tag1": {"total": 0.8}},
+            "tag_timespans": {}
+        }
+
+        video_info = haven_vlm_engine.VideoTagInfo.from_json(json_data)
+
+        self.assertEqual(video_info.video_duration, 120.0)
+        self.assertEqual(video_info.tag_timespans, {})
+
+    def test_videotaginfo_str(self):
+        """Test VideoTagInfo string representation"""
+        video_info = haven_vlm_engine.VideoTagInfo(
+            video_duration=120.0,
+            video_tags={"category1": {"tag1"}},
+            tag_totals={"tag1": {"total": 0.8}},
+            tag_timespans={"category1": {"tag1": []}}
+        )
+
+        str_repr = str(video_info)
+        self.assertIn("120.0", str_repr)
+        self.assertIn("1", str_repr)  # number of tags
+        self.assertIn("1", str_repr)  # number of timespans
+
+
+class TestImageResult(unittest.TestCase):
+    """Test cases for ImageResult dataclass"""
+
+    def test_imageresult_creation(self):
+        """Test creating ImageResult with valid data"""
+        result_data = [{"tags": ["tag1"], "confidence": 0.8}]
+        image_result = haven_vlm_engine.ImageResult(result=result_data)
+
+        self.assertEqual(image_result.result, result_data)
+
+    def test_imageresult_creation_empty_list(self):
+        """Test creating ImageResult with empty list"""
+        with self.assertRaises(ValueError):
+            haven_vlm_engine.ImageResult(result=[])
+
+    def test_imageresult_creation_none_result(self):
+        """Test creating ImageResult with None result"""
+        with self.assertRaises(ValueError):
+            haven_vlm_engine.ImageResult(result=None)
+
+
+# IsolatedAsyncioTestCase (Python 3.8+) is required so the async test methods
+# below are actually awaited; under plain TestCase they would never execute.
+class TestHavenVLMEngine(unittest.IsolatedAsyncioTestCase):
+    """Test cases for HavenVLMEngine class"""
+
+    def setUp(self):
+        """Set up test fixtures"""
+        self.engine = haven_vlm_engine.HavenVLMEngine()
+
+    def test_engine_initialization(self):
+        """Test engine initialization"""
+        self.assertIsNone(self.engine.engine)
+        self.assertIsNone(self.engine.engine_config)
+        self.assertFalse(self.engine._initialized)
+
+    @patch('haven_vlm_engine.config')
+    @patch('haven_vlm_engine.VLMEngine')
+    async def test_initialize_success(self, mock_vlm_engine_class, mock_config):
+        """Test successful engine initialization"""
+        # AsyncMock so the engine's awaited calls resolve to real values
+        mock_engine_instance = unittest.mock.AsyncMock()
+        mock_vlm_engine_class.return_value = mock_engine_instance
+        mock_config.config.vlm_engine_config = {"test": "config"}
+
+        await self.engine.initialize()
+
+        self.assertTrue(self.engine._initialized)
+        mock_vlm_engine_class.assert_called_once()
+        mock_engine_instance.initialize.assert_called_once()
+
+    @patch('haven_vlm_engine.config')
+    @patch('haven_vlm_engine.VLMEngine')
+    async def test_initialize_already_initialized(self, mock_vlm_engine_class, mock_config):
+        """Test initialization when already initialized"""
+        self.engine._initialized = True
+
+        await self.engine.initialize()
+
+        # Should not call VLMEngine constructor again
+        mock_vlm_engine_class.assert_not_called()
+
+    @patch('haven_vlm_engine.config')
+    @patch('haven_vlm_engine.VLMEngine')
+    async def test_initialize_error(self, mock_vlm_engine_class, mock_config):
+        """Test initialization with error"""
+        mock_vlm_engine_class.side_effect = Exception("Initialization failed")
+        mock_config.config.vlm_engine_config = {"test": "config"}
+
+        with self.assertRaises(Exception):
+            await self.engine.initialize()
+
+        self.assertFalse(self.engine._initialized)
+
+    @patch('haven_vlm_engine.config')
+    def test_create_engine_config(self, mock_config):
+        """Test creating engine configuration"""
+        mock_config.config.vlm_engine_config = {
+            "active_ai_models": ["model1"],
+            "pipelines": {
+                "pipeline1": {
+                    "inputs": ["input1"],
+                    "output": "output1",
+                    "short_name": "short1",
+                    "version": 1.0,
+                    "models": [
+                        {
+                            "name": "model1",
+                            "inputs": ["input1"],
+                            "outputs": "output1"
+                        }
+                    ]
+                }
+            },
+            "models": {
+                "model1": {
+                    "type": "vlm_model",
+                    "model_file_name": "model1.py",
+                    "model_category": "test",
+                    "model_id": "test_model",
+                    "model_identifier": 123,
+                    "model_version": "1.0",
+                    "use_multiplexer": True,
+                    "max_concurrent_requests": 10,
+                    "connection_pool_size": 20,
+                    "multiplexer_endpoints": [],
+                    "tag_list": ["tag1"]
+                }
+            },
+            "category_config": {"test": {}}
+        }
+
+        config = self.engine._create_engine_config()
+
+        self.assertIsNotNone(config)
+        # Note: we can't easily test the exact structure without the actual
+        # VLM Engine classes, but we can verify the method doesn't raise
+
+    @patch('haven_vlm_engine.config')
+    @patch('haven_vlm_engine.VLMEngine')
+    async def test_process_video_success(self, mock_vlm_engine_class, mock_config):
+        """Test successful video processing"""
+        # Setup mocks
+        mock_engine_instance = unittest.mock.AsyncMock()
+        mock_vlm_engine_class.return_value = mock_engine_instance
+        mock_config.config.video_frame_interval = 2.0
+        mock_config.config.video_threshold = 0.3
+        mock_config.config.video_confidence_return = True
+
+        # Mock the engine's process_video method
+        mock_engine_instance.process_video.return_value = {
+            "video_duration": 120.0,
+            "video_tags": {"category1": ["tag1"]},
+            "tag_totals": {"tag1": {"total": 0.8}},
+            "tag_timespans": {}
+        }
+
+        # Initialize engine
+        await self.engine.initialize()
+
+        # Process video
+        result = await self.engine.process_video("/path/to/video.mp4")
+
+        self.assertIsInstance(result, haven_vlm_engine.VideoTagInfo)
+        mock_engine_instance.process_video.assert_called_once()
+
+    @patch('haven_vlm_engine.config')
+    @patch('haven_vlm_engine.VLMEngine')
+    async def test_process_video_not_initialized(self, mock_vlm_engine_class, mock_config):
+        """Test video processing when not initialized"""
+        mock_engine_instance = unittest.mock.AsyncMock()
+        mock_vlm_engine_class.return_value = mock_engine_instance
+        mock_config.config.video_frame_interval = 2.0
+        mock_config.config.video_threshold = 0.3
+        mock_config.config.video_confidence_return = True
+
+        mock_engine_instance.process_video.return_value = {
+            "video_duration": 120.0,
+            "video_tags": {"category1": ["tag1"]},
+            "tag_totals": {"tag1": {"total": 0.8}},
+            "tag_timespans": {}
+        }
+
+        # Process video without explicit initialization
+        result = await self.engine.process_video("/path/to/video.mp4")
+
+        self.assertIsInstance(result, haven_vlm_engine.VideoTagInfo)
+        mock_engine_instance.initialize.assert_called_once()
+
+    @patch('haven_vlm_engine.config')
+    @patch('haven_vlm_engine.VLMEngine')
+    async def test_process_video_error(self, mock_vlm_engine_class, mock_config):
+        """Test video processing with error"""
+        mock_engine_instance = unittest.mock.AsyncMock()
+        mock_vlm_engine_class.return_value = mock_engine_instance
+        mock_config.config.video_frame_interval = 2.0
+        mock_config.config.video_threshold = 0.3
+        mock_config.config.video_confidence_return = True
+
+        mock_engine_instance.process_video.side_effect = Exception("Processing failed")
+
+        await self.engine.initialize()
+
+        with self.assertRaises(Exception):
+            await self.engine.process_video("/path/to/video.mp4")
+
+    @patch('haven_vlm_engine.config')
+    @patch('haven_vlm_engine.VLMEngine')
+    async def test_process_images_success(self, mock_vlm_engine_class, mock_config):
+        """Test successful image processing"""
+        # Setup mocks
+        mock_engine_instance = unittest.mock.AsyncMock()
+        mock_vlm_engine_class.return_value = mock_engine_instance
+        mock_config.config.image_threshold = 0.5
+        mock_config.config.image_confidence_return = False
+
+        # Mock the engine's process_images method
+        mock_engine_instance.process_images.return_value = [
+            {"tags": ["tag1"], "confidence": 0.8}
+        ]
+
+        # Initialize engine
+        await self.engine.initialize()
+
+        # Process images
+        result = await self.engine.process_images(["/path/to/image1.jpg"])
+
+        self.assertIsInstance(result, haven_vlm_engine.ImageResult)
+        mock_engine_instance.process_images.assert_called_once()
+
+    @patch('haven_vlm_engine.config')
+    @patch('haven_vlm_engine.VLMEngine')
+    async def test_process_images_error(self, mock_vlm_engine_class, mock_config):
+        """Test image processing with error"""
+        mock_engine_instance = unittest.mock.AsyncMock()
+        mock_vlm_engine_class.return_value = mock_engine_instance
+        mock_config.config.image_threshold = 0.5
+        mock_config.config.image_confidence_return = False
+
+        mock_engine_instance.process_images.side_effect = Exception("Processing failed")
+
+        await self.engine.initialize()
+
+        with self.assertRaises(Exception):
+            await self.engine.process_images(["/path/to/image1.jpg"])
+
+    @patch('haven_vlm_engine.config')
+    @patch('haven_vlm_engine.VLMEngine')
+    async def test_find_optimal_marker_settings_success(self, mock_vlm_engine_class, mock_config):
+        """Test successful marker settings optimization"""
+        # Setup mocks
+        mock_engine_instance = unittest.mock.AsyncMock()
+        mock_vlm_engine_class.return_value = mock_engine_instance
+        mock_engine_instance.optimize_timeframe_settings.return_value = {"optimal": "settings"}
+
+        # Initialize engine
+        await self.engine.initialize()
+
+        # Test data
+        existing_json = {"existing": "data"}
+        desired_timespan_data = {
+            "tag1": haven_vlm_engine.TimeFrame(10.0, 15.0, 0.8)
+        }
+
+        # Find optimal settings
+        result = await self.engine.find_optimal_marker_settings(existing_json, desired_timespan_data)
+
+        self.assertEqual(result, {"optimal": "settings"})
+        mock_engine_instance.optimize_timeframe_settings.assert_called_once()
+
+    @patch('haven_vlm_engine.config')
+    @patch('haven_vlm_engine.VLMEngine')
+    async def test_find_optimal_marker_settings_error(self, mock_vlm_engine_class, mock_config):
+        """Test marker settings optimization with error"""
+        mock_engine_instance = unittest.mock.AsyncMock()
+        mock_vlm_engine_class.return_value = mock_engine_instance
+        mock_engine_instance.optimize_timeframe_settings.side_effect = Exception("Optimization failed")
+
+        await self.engine.initialize()
+
+        existing_json = {"existing": "data"}
+        desired_timespan_data = {
+            "tag1": haven_vlm_engine.TimeFrame(10.0, 15.0, 0.8)
+        }
+
+        with self.assertRaises(Exception):
+            await self.engine.find_optimal_marker_settings(existing_json, desired_timespan_data)
+
+    @patch('haven_vlm_engine.config')
+    @patch('haven_vlm_engine.VLMEngine')
+    async def test_shutdown_success(self, mock_vlm_engine_class, mock_config):
+        """Test successful engine shutdown"""
+        mock_engine_instance = unittest.mock.AsyncMock()
+        mock_vlm_engine_class.return_value = mock_engine_instance
+
+        await self.engine.initialize()
+        await self.engine.shutdown()
+
+        mock_engine_instance.shutdown.assert_called_once()
+        self.assertFalse(self.engine._initialized)
+
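+    # Illustrative sketch, not part of the original suite: assuming
+    # process_images auto-initializes the engine the same way
+    # test_process_video_not_initialized shows for process_video, the image
+    # path can be covered with the same pattern.
+    @patch('haven_vlm_engine.config')
+    @patch('haven_vlm_engine.VLMEngine')
+    async def test_process_images_not_initialized(self, mock_vlm_engine_class, mock_config):
+        """Illustrative sketch: image processing without explicit initialization"""
+        mock_engine_instance = unittest.mock.AsyncMock()
+        mock_vlm_engine_class.return_value = mock_engine_instance
+        mock_config.config.image_threshold = 0.5
+        mock_config.config.image_confidence_return = False
+
+        mock_engine_instance.process_images.return_value = [
+            {"tags": ["tag1"], "confidence": 0.8}
+        ]
+
+        # Process images without calling initialize() first
+        result = await self.engine.process_images(["/path/to/image1.jpg"])
+
+        self.assertIsInstance(result, haven_vlm_engine.ImageResult)
+        mock_engine_instance.initialize.assert_called_once()
+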
+    @patch('haven_vlm_engine.config')
+    @patch('haven_vlm_engine.VLMEngine')
+    async def test_shutdown_not_initialized(self, mock_vlm_engine_class, mock_config):
+        """Test shutdown when not initialized"""
+        await self.engine.shutdown()
+
+        # Should not raise any exceptions
+        self.assertFalse(self.engine._initialized)
+
+    @patch('haven_vlm_engine.config')
+    @patch('haven_vlm_engine.VLMEngine')
+    async def test_shutdown_error(self, mock_vlm_engine_class, mock_config):
+        """Test shutdown with error"""
+        mock_engine_instance = unittest.mock.AsyncMock()
+        mock_vlm_engine_class.return_value = mock_engine_instance
+        mock_engine_instance.shutdown.side_effect = Exception("Shutdown failed")
+
+        await self.engine.initialize()
+        await self.engine.shutdown()
+
+        # Should handle the error gracefully
+        self.assertFalse(self.engine._initialized)
+
+
+# IsolatedAsyncioTestCase so the awaited convenience functions actually run.
+class TestConvenienceFunctions(unittest.IsolatedAsyncioTestCase):
+    """Test cases for convenience functions"""
+
+    @patch('haven_vlm_engine.vlm_engine')
+    async def test_process_video_async(self, mock_vlm_engine):
+        """Test process_video_async convenience function"""
+        # The engine method is awaited, so it must be an AsyncMock
+        mock_vlm_engine.process_video = unittest.mock.AsyncMock(return_value=MagicMock())
+
+        result = await haven_vlm_engine.process_video_async("/path/to/video.mp4")
+
+        mock_vlm_engine.process_video.assert_called_once()
+        self.assertEqual(result, mock_vlm_engine.process_video.return_value)
+
+    @patch('haven_vlm_engine.vlm_engine')
+    async def test_process_images_async(self, mock_vlm_engine):
+        """Test process_images_async convenience function"""
+        mock_vlm_engine.process_images = unittest.mock.AsyncMock(return_value=MagicMock())
+
+        result = await haven_vlm_engine.process_images_async(["/path/to/image.jpg"])
+
+        mock_vlm_engine.process_images.assert_called_once()
+        self.assertEqual(result, mock_vlm_engine.process_images.return_value)
+
+    @patch('haven_vlm_engine.vlm_engine')
+    async def test_find_optimal_marker_settings_async(self, mock_vlm_engine):
+        """Test find_optimal_marker_settings_async convenience function"""
+        mock_vlm_engine.find_optimal_marker_settings = unittest.mock.AsyncMock(
+            return_value={"optimal": "settings"}
+        )
+
+        existing_json = {"existing": "data"}
+        desired_timespan_data = {
+            "tag1": haven_vlm_engine.TimeFrame(10.0, 15.0, 0.8)
+        }
+
+        result = await haven_vlm_engine.find_optimal_marker_settings_async(existing_json, desired_timespan_data)
+
+        mock_vlm_engine.find_optimal_marker_settings.assert_called_once()
+        self.assertEqual(result, {"optimal": "settings"})
+
+
+class TestGlobalVLMEngineInstance(unittest.TestCase):
+    """Test cases for global VLM engine instance"""
+
+    def test_global_vlm_engine_exists(self):
+        """Test that global VLM engine instance exists"""
+        self.assertIsInstance(haven_vlm_engine.vlm_engine, haven_vlm_engine.HavenVLMEngine)
+
+    def test_global_vlm_engine_is_singleton(self):
+        """Test that global VLM engine is a singleton"""
+        engine1 = haven_vlm_engine.vlm_engine
+        engine2 = haven_vlm_engine.vlm_engine
+
+        self.assertIs(engine1, engine2)
+
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
diff --git a/plugins/AHavenVLMConnector/test_haven_vlm_utility.py b/plugins/AHavenVLMConnector/test_haven_vlm_utility.py
new file mode 100644
index 00000000..be3d6263
--- /dev/null
+++ b/plugins/AHavenVLMConnector/test_haven_vlm_utility.py
@@ -0,0 +1,604 @@
+"""
+Unit tests for haven_vlm_utility module
+"""
+
+import unittest
+import tempfile
+import os
+import shutil
+import time
+from unittest.mock import patch, mock_open, MagicMock
+import yaml
+
+import haven_vlm_utility
+
+
+class TestPathMutations(unittest.TestCase):
+    """Test cases for path mutation 
functions""" + + def test_apply_path_mutations_with_mutations(self): + """Test applying path mutations with valid mutations""" + mutations = {"E:": "F:", "G:": "D:"} + path = "E:\\videos\\test.mp4" + + result = haven_vlm_utility.apply_path_mutations(path, mutations) + + self.assertEqual(result, "F:\\videos\\test.mp4") + + def test_apply_path_mutations_without_mutations(self): + """Test applying path mutations with empty mutations""" + mutations = {} + path = "E:\\videos\\test.mp4" + + result = haven_vlm_utility.apply_path_mutations(path, mutations) + + self.assertEqual(result, path) + + def test_apply_path_mutations_with_none_mutations(self): + """Test applying path mutations with None mutations""" + mutations = None + path = "E:\\videos\\test.mp4" + + result = haven_vlm_utility.apply_path_mutations(path, mutations) + + self.assertEqual(result, path) + + def test_apply_path_mutations_no_match(self): + """Test applying path mutations when no mutation matches""" + mutations = {"E:": "F:", "G:": "D:"} + path = "C:\\videos\\test.mp4" + + result = haven_vlm_utility.apply_path_mutations(path, mutations) + + self.assertEqual(result, path) + + def test_apply_path_mutations_multiple_matches(self): + """Test applying path mutations with multiple possible matches""" + mutations = {"E:": "F:", "E:\\videos": "F:\\movies"} + path = "E:\\videos\\test.mp4" + + result = haven_vlm_utility.apply_path_mutations(path, mutations) + + # Should use the first match + self.assertEqual(result, "F:\\videos\\test.mp4") + + +class TestDirectoryOperations(unittest.TestCase): + """Test cases for directory operations""" + + def test_ensure_directory_exists_new_directory(self): + """Test creating a new directory""" + with tempfile.TemporaryDirectory() as temp_dir: + new_dir = os.path.join(temp_dir, "test_subdir") + + haven_vlm_utility.ensure_directory_exists(new_dir) + + self.assertTrue(os.path.exists(new_dir)) + self.assertTrue(os.path.isdir(new_dir)) + + def test_ensure_directory_exists_existing_directory(self): + """Test ensuring directory exists when it already exists""" + with tempfile.TemporaryDirectory() as temp_dir: + haven_vlm_utility.ensure_directory_exists(temp_dir) + + self.assertTrue(os.path.exists(temp_dir)) + self.assertTrue(os.path.isdir(temp_dir)) + + def test_ensure_directory_exists_nested_directories(self): + """Test creating nested directories""" + with tempfile.TemporaryDirectory() as temp_dir: + nested_dir = os.path.join(temp_dir, "level1", "level2", "level3") + + haven_vlm_utility.ensure_directory_exists(nested_dir) + + self.assertTrue(os.path.exists(nested_dir)) + self.assertTrue(os.path.isdir(nested_dir)) + + +class TestSafeFileOperations(unittest.TestCase): + """Test cases for safe file operations""" + + def test_safe_file_operation_success(self): + """Test successful file operation""" + def test_func(a, b, c=10): + return a + b + c + + result = haven_vlm_utility.safe_file_operation(test_func, 1, 2, c=5) + + self.assertEqual(result, 8) + + def test_safe_file_operation_os_error(self): + """Test file operation with OSError""" + def test_func(): + raise OSError("File not found") + + result = haven_vlm_utility.safe_file_operation(test_func) + + self.assertIsNone(result) + + def test_safe_file_operation_io_error(self): + """Test file operation with IOError""" + def test_func(): + raise IOError("Permission denied") + + result = haven_vlm_utility.safe_file_operation(test_func) + + self.assertIsNone(result) + + def test_safe_file_operation_unexpected_error(self): + """Test file operation with unexpected 
error""" + def test_func(): + raise ValueError("Unexpected error") + + result = haven_vlm_utility.safe_file_operation(test_func) + + self.assertIsNone(result) + + +class TestYamlConfigOperations(unittest.TestCase): + """Test cases for YAML configuration operations""" + + def setUp(self): + """Set up test fixtures""" + self.test_config = { + "video_frame_interval": 2.0, + "video_threshold": 0.3, + "image_threshold": 0.5, + "endpoints": [ + {"url": "http://localhost:1234", "weight": 5}, + {"url": "https://cloud.example.com", "weight": 1} + ] + } + + def test_load_yaml_config_success(self): + """Test successfully loading YAML configuration""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yml', delete=False) as f: + yaml.dump(self.test_config, f) + config_path = f.name + + try: + result = haven_vlm_utility.load_yaml_config(config_path) + + self.assertEqual(result, self.test_config) + finally: + os.unlink(config_path) + + def test_load_yaml_config_file_not_found(self): + """Test loading YAML configuration from nonexistent file""" + result = haven_vlm_utility.load_yaml_config("nonexistent_file.yml") + + self.assertIsNone(result) + + def test_load_yaml_config_invalid_yaml(self): + """Test loading YAML configuration with invalid YAML""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yml', delete=False) as f: + f.write("invalid: yaml: content: [") + config_path = f.name + + try: + result = haven_vlm_utility.load_yaml_config(config_path) + + self.assertIsNone(result) + finally: + os.unlink(config_path) + + def test_load_yaml_config_permission_error(self): + """Test loading YAML configuration with permission error""" + with patch('builtins.open', side_effect=PermissionError("Permission denied")): + result = haven_vlm_utility.load_yaml_config("test.yml") + + self.assertIsNone(result) + + def test_save_yaml_config_success(self): + """Test successfully saving YAML configuration""" + with tempfile.TemporaryDirectory() as temp_dir: + config_path = os.path.join(temp_dir, "test_config.yml") + + result = haven_vlm_utility.save_yaml_config(self.test_config, config_path) + + self.assertTrue(result) + self.assertTrue(os.path.exists(config_path)) + + # Verify the saved content + with open(config_path, 'r') as f: + loaded_config = yaml.safe_load(f) + + self.assertEqual(loaded_config, self.test_config) + + def test_save_yaml_config_with_nested_directories(self): + """Test saving YAML configuration to nested directory""" + with tempfile.TemporaryDirectory() as temp_dir: + config_path = os.path.join(temp_dir, "nested", "dir", "test_config.yml") + + result = haven_vlm_utility.save_yaml_config(self.test_config, config_path) + + self.assertTrue(result) + self.assertTrue(os.path.exists(config_path)) + + def test_save_yaml_config_permission_error(self): + """Test saving YAML configuration with permission error""" + with patch('builtins.open', side_effect=PermissionError("Permission denied")): + result = haven_vlm_utility.save_yaml_config(self.test_config, "test.yml") + + self.assertFalse(result) + + +class TestFileValidation(unittest.TestCase): + """Test cases for file validation functions""" + + def test_validate_file_path_existing_file(self): + """Test validating an existing file path""" + with tempfile.NamedTemporaryFile(delete=False) as f: + file_path = f.name + + try: + result = haven_vlm_utility.validate_file_path(file_path) + self.assertTrue(result) + finally: + os.unlink(file_path) + + def test_validate_file_path_nonexistent_file(self): + """Test validating a nonexistent file path""" + result = 
haven_vlm_utility.validate_file_path("nonexistent_file.txt") + self.assertFalse(result) + + def test_validate_file_path_directory(self): + """Test validating a directory path""" + with tempfile.TemporaryDirectory() as temp_dir: + result = haven_vlm_utility.validate_file_path(temp_dir) + self.assertFalse(result) + + def test_validate_file_path_permission_error(self): + """Test validating file path with permission error""" + with patch('os.path.isfile', side_effect=OSError("Permission denied")): + result = haven_vlm_utility.validate_file_path("test.txt") + self.assertFalse(result) + + +class TestFileExtensionFunctions(unittest.TestCase): + """Test cases for file extension functions""" + + def test_get_file_extension_with_extension(self): + """Test getting file extension from file with extension""" + result = haven_vlm_utility.get_file_extension("test.mp4") + self.assertEqual(result, ".mp4") + + def test_get_file_extension_without_extension(self): + """Test getting file extension from file without extension""" + result = haven_vlm_utility.get_file_extension("test") + self.assertEqual(result, "") + + def test_get_file_extension_multiple_dots(self): + """Test getting file extension from file with multiple dots""" + result = haven_vlm_utility.get_file_extension("test.backup.mp4") + self.assertEqual(result, ".mp4") + + def test_get_file_extension_uppercase(self): + """Test getting file extension from file with uppercase extension""" + result = haven_vlm_utility.get_file_extension("test.MP4") + self.assertEqual(result, ".mp4") + + def test_is_video_file_valid_extensions(self): + """Test video file detection with valid extensions""" + video_files = ["test.mp4", "test.avi", "test.mkv", "test.mov", "test.wmv", "test.flv", "test.webm", "test.m4v"] + + for video_file in video_files: + result = haven_vlm_utility.is_video_file(video_file) + self.assertTrue(result, f"Failed for {video_file}") + + def test_is_video_file_invalid_extensions(self): + """Test video file detection with invalid extensions""" + non_video_files = ["test.jpg", "test.txt", "test.pdf", "test.exe"] + + for non_video_file in non_video_files: + result = haven_vlm_utility.is_video_file(non_video_file) + self.assertFalse(result, f"Failed for {non_video_file}") + + def test_is_image_file_valid_extensions(self): + """Test image file detection with valid extensions""" + image_files = ["test.jpg", "test.jpeg", "test.png", "test.gif", "test.bmp", "test.tiff", "test.webp"] + + for image_file in image_files: + result = haven_vlm_utility.is_image_file(image_file) + self.assertTrue(result, f"Failed for {image_file}") + + def test_is_image_file_invalid_extensions(self): + """Test image file detection with invalid extensions""" + non_image_files = ["test.mp4", "test.txt", "test.pdf", "test.exe"] + + for non_image_file in non_image_files: + result = haven_vlm_utility.is_image_file(non_image_file) + self.assertFalse(result, f"Failed for {non_image_file}") + + +class TestFormattingFunctions(unittest.TestCase): + """Test cases for formatting functions""" + + def test_format_duration_seconds(self): + """Test formatting duration in seconds""" + result = haven_vlm_utility.format_duration(45.5) + self.assertEqual(result, "45.5s") + + def test_format_duration_minutes(self): + """Test formatting duration in minutes""" + result = haven_vlm_utility.format_duration(125.3) + self.assertEqual(result, "2m 5.3s") + + def test_format_duration_hours(self): + """Test formatting duration in hours""" + result = haven_vlm_utility.format_duration(7325.7) + 
self.assertEqual(result, "2h 2m 5.7s")
+
+    def test_format_duration_zero(self):
+        """Test formatting zero duration"""
+        result = haven_vlm_utility.format_duration(0)
+        self.assertEqual(result, "0.0s")
+
+    def test_format_file_size_bytes(self):
+        """Test formatting file size in bytes"""
+        result = haven_vlm_utility.format_file_size(512)
+        self.assertEqual(result, "512.0 B")
+
+    def test_format_file_size_kilobytes(self):
+        """Test formatting file size in kilobytes"""
+        result = haven_vlm_utility.format_file_size(1536)
+        self.assertEqual(result, "1.5 KB")
+
+    def test_format_file_size_megabytes(self):
+        """Test formatting file size in megabytes"""
+        result = haven_vlm_utility.format_file_size(1572864)
+        self.assertEqual(result, "1.5 MB")
+
+    def test_format_file_size_gigabytes(self):
+        """Test formatting file size in gigabytes"""
+        result = haven_vlm_utility.format_file_size(1610612736)
+        self.assertEqual(result, "1.5 GB")
+
+    def test_format_file_size_zero(self):
+        """Test formatting zero file size"""
+        result = haven_vlm_utility.format_file_size(0)
+        self.assertEqual(result, "0.0 B")
+
+
+class TestSanitizationFunctions(unittest.TestCase):
+    """Test cases for sanitization functions"""
+
+    def test_sanitize_filename_valid(self):
+        """Test sanitizing a valid filename"""
+        result = haven_vlm_utility.sanitize_filename("valid_filename.txt")
+        self.assertEqual(result, "valid_filename.txt")
+
+    def test_sanitize_filename_invalid_chars(self):
+        """Test sanitizing filename with invalid characters"""
+        result = haven_vlm_utility.sanitize_filename("file:with/invalid\\chars|?*")
+        self.assertEqual(result, "file_with_invalid_chars___")
+
+    def test_sanitize_filename_leading_trailing_spaces(self):
+        """Test sanitizing filename with leading/trailing spaces"""
+        result = haven_vlm_utility.sanitize_filename("  filename.txt  ")
+        self.assertEqual(result, "filename.txt")
+
+    def test_sanitize_filename_leading_trailing_dots(self):
+        """Test sanitizing filename with leading/trailing dots"""
+        result = haven_vlm_utility.sanitize_filename("...filename.txt...")
+        self.assertEqual(result, "filename.txt")
+
+    def test_sanitize_filename_empty(self):
+        """Test sanitizing empty filename"""
+        result = haven_vlm_utility.sanitize_filename("")
+        self.assertEqual(result, "unnamed")
+
+    def test_sanitize_filename_only_spaces(self):
+        """Test sanitizing filename with only spaces"""
+        result = haven_vlm_utility.sanitize_filename("   ")
+        self.assertEqual(result, "unnamed")
+
+
+class TestBackupFunctions(unittest.TestCase):
+    """Test cases for backup functions"""
+
+    def test_create_backup_file_success(self):
+        """Test successfully creating a backup file"""
+        with tempfile.NamedTemporaryFile(delete=False) as f:
+            original_file = f.name
+            f.write(b"test content")
+
+        try:
+            result = haven_vlm_utility.create_backup_file(original_file)
+
+            self.assertIsNotNone(result)
+            self.assertTrue(os.path.exists(result))
+            self.assertTrue(result.endswith(".backup"))
+
+            # Verify backup content
+            with open(result, 'rb') as f:
+                content = f.read()
+            self.assertEqual(content, b"test content")
+
+            # Clean up backup
+            os.unlink(result)
+        finally:
+            os.unlink(original_file)
+
+    def test_create_backup_file_custom_suffix(self):
+        """Test creating backup file with custom suffix"""
+        with tempfile.NamedTemporaryFile(delete=False) as f:
+            original_file = f.name
+            f.write(b"test content")
+
+        try:
+            result = haven_vlm_utility.create_backup_file(original_file, ".custom")
+
+            self.assertIsNotNone(result)
+            self.assertTrue(result.endswith(".custom"))
+
+            # Clean up backup
+            os.unlink(result)
+        finally:
+            os.unlink(original_file)
+
+    def test_create_backup_file_nonexistent(self):
+        """Test creating backup of nonexistent file"""
+        result = haven_vlm_utility.create_backup_file("nonexistent_file.txt")
+        self.assertIsNone(result)
+
+    def test_create_backup_file_permission_error(self):
+        """Test creating backup file with permission error"""
+        with patch('shutil.copy2', side_effect=PermissionError("Permission denied")):
+            with tempfile.NamedTemporaryFile(delete=False) as f:
+                original_file = f.name
+
+            try:
+                result = haven_vlm_utility.create_backup_file(original_file)
+                self.assertIsNone(result)
+            finally:
+                os.unlink(original_file)
+
+
+class TestDictionaryOperations(unittest.TestCase):
+    """Test cases for dictionary operations"""
+
+    def test_merge_dictionaries_simple(self):
+        """Test simple dictionary merging"""
+        dict1 = {"a": 1, "b": 2}
+        dict2 = {"c": 3, "d": 4}
+
+        result = haven_vlm_utility.merge_dictionaries(dict1, dict2)
+
+        expected = {"a": 1, "b": 2, "c": 3, "d": 4}
+        self.assertEqual(result, expected)
+
+    def test_merge_dictionaries_overwrite(self):
+        """Test dictionary merging with overwrite"""
+        dict1 = {"a": 1, "b": 2}
+        dict2 = {"b": 3, "c": 4}
+
+        result = haven_vlm_utility.merge_dictionaries(dict1, dict2, overwrite=True)
+
+        expected = {"a": 1, "b": 3, "c": 4}
+        self.assertEqual(result, expected)
+
+    def test_merge_dictionaries_no_overwrite(self):
+        """Test dictionary merging without overwrite"""
+        dict1 = {"a": 1, "b": 2}
+        dict2 = {"b": 3, "c": 4}
+
+        result = haven_vlm_utility.merge_dictionaries(dict1, dict2, overwrite=False)
+
+        expected = {"a": 1, "b": 2, "c": 4}
+        self.assertEqual(result, expected)
+
+    def test_merge_dictionaries_nested(self):
+        """Test merging nested dictionaries"""
+        dict1 = {"a": 1, "b": {"x": 10, "y": 20}}
+        dict2 = {"c": 3, "b": {"y": 25, "z": 30}}
+
+        result = haven_vlm_utility.merge_dictionaries(dict1, dict2, overwrite=True)
+
+        expected = {"a": 1, "b": {"x": 10, "y": 25, "z": 30}, "c": 3}
+        self.assertEqual(result, expected)
+
+    def test_merge_dictionaries_empty(self):
+        """Test merging with empty dictionaries"""
+        dict1 = {}
+        dict2 = {"a": 1, "b": 2}
+
+        result = haven_vlm_utility.merge_dictionaries(dict1, dict2)
+
+        self.assertEqual(result, dict2)
+
+
+class TestListOperations(unittest.TestCase):
+    """Test cases for list operations"""
+
+    def test_chunk_list_even_chunks(self):
+        """Test chunking list into even chunks"""
+        lst = [1, 2, 3, 4, 5, 6]
+        result = haven_vlm_utility.chunk_list(lst, 2)
+
+        expected = [[1, 2], [3, 4], [5, 6]]
+        self.assertEqual(result, expected)
+
+    def test_chunk_list_uneven_chunks(self):
+        """Test chunking list into uneven chunks"""
+        lst = [1, 2, 3, 4, 5]
+        result = haven_vlm_utility.chunk_list(lst, 2)
+
+        expected = [[1, 2], [3, 4], [5]]
+        self.assertEqual(result, expected)
+
+    def test_chunk_list_empty_list(self):
+        """Test chunking empty list"""
+        lst = []
+        result = haven_vlm_utility.chunk_list(lst, 3)
+
+        expected = []
+        self.assertEqual(result, expected)
+
+    def test_chunk_list_chunk_size_larger_than_list(self):
+        """Test chunking when chunk size is larger than list"""
+        lst = [1, 2, 3]
+        result = haven_vlm_utility.chunk_list(lst, 5)
+
+        expected = [[1, 2, 3]]
+        self.assertEqual(result, expected)
+
+
+class TestRetryOperations(unittest.TestCase):
+    """Test cases for retry operations"""
+
+    def test_retry_operation_success_first_try(self):
+        """Test retry operation that succeeds on first try"""
+        def test_func():
+            return "success"
+
+        result = haven_vlm_utility.retry_operation(test_func)
+
+        self.assertEqual(result, "success")
+
+    def test_retry_operation_success_after_retries(self):
+        """Test retry operation that succeeds after some retries"""
+        call_count = 0
+
+        def test_func():
+            nonlocal call_count
+            call_count += 1
+            if call_count < 3:
+                raise ValueError("Temporary error")
+            return "success"
+
+        result = haven_vlm_utility.retry_operation(test_func, max_retries=3, delay=0.1)
+
+        self.assertEqual(result, "success")
+        self.assertEqual(call_count, 3)
+
+    def test_retry_operation_all_retries_fail(self):
+        """Test retry operation that fails all retries"""
+        def test_func():
+            raise ValueError("Persistent error")
+
+        result = haven_vlm_utility.retry_operation(test_func, max_retries=2, delay=0.1)
+
+        self.assertIsNone(result)
+
+    def test_retry_operation_with_arguments(self):
+        """Test retry operation with function arguments"""
+        def test_func(a, b, c=10):
+            return a + b + c
+
+        result = haven_vlm_utility.retry_operation(test_func, 1, 2, max_retries=1, delay=0.1, c=5)
+
+        self.assertEqual(result, 8)
+
+    def test_retry_operation_with_keyword_arguments(self):
+        """Test retry operation with keyword arguments"""
+        def test_func(**kwargs):
+            return kwargs.get('value', 0)
+
+        result = haven_vlm_utility.retry_operation(test_func, max_retries=1, delay=0.1, value=42)
+
+        self.assertEqual(result, 42)
+
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
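
For reference, the retry contract that TestRetryOperations pins down — up to `max_retries` attempts, a fixed inter-attempt `delay`, positional and keyword arguments forwarded to the wrapped callable, and `None` once every attempt has failed — can be satisfied by a function shaped like the minimal sketch below. This is an illustrative assumption about the shape of the API, not the actual code in `haven_vlm_utility.py` (which this patch adds but is not reproduced here):

```python
import time
from typing import Any, Callable, Optional

def retry_operation(func: Callable[..., Any], *args: Any,
                    max_retries: int = 3, delay: float = 1.0,
                    **kwargs: Any) -> Optional[Any]:
    """Call func with args/kwargs, retrying on any exception.

    Makes up to max_retries attempts, sleeping `delay` seconds between
    attempts, and returns None once all attempts have failed -- matching
    the behavior asserted by TestRetryOperations above.
    """
    for attempt in range(max_retries):
        try:
            return func(*args, **kwargs)
        except Exception:
            if attempt + 1 < max_retries:
                time.sleep(delay)  # back off before the next attempt
    return None
```

Under this shape, `retry_operation(test_func, 1, 2, max_retries=1, delay=0.1, c=5)` makes a single attempt and returns `8`, and `max_retries=3` with a function that fails twice succeeds on the third call, exactly as the tests assert.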