diff --git a/gotelegram-bot/bot.py b/gotelegram-bot/bot.py index 014a8ef..08ccfc6 100644 --- a/gotelegram-bot/bot.py +++ b/gotelegram-bot/bot.py @@ -402,6 +402,43 @@ def template_display_name(template_id: str) -> str: return template_id +def pro_template_map(context: ContextTypes.DEFAULT_TYPE) -> Dict[str, str]: + """Return the short callback key -> template id map for this chat.""" + mapping = context.user_data.setdefault("pro_template_map", {}) + if not isinstance(mapping, dict): + mapping = {} + context.user_data["pro_template_map"] = mapping + return mapping + + +def resolve_pro_template_id(context: ContextTypes.DEFAULT_TYPE, key_or_id: str) -> str: + """Resolve a short Telegram callback key back to the real template id.""" + mapped = pro_template_map(context).get(key_or_id) + if mapped: + return str(mapped) + + catalog = load_json(TEMPLATES_CATALOG) or {} + for cat in catalog.get("categories", []): + for tpl in cat.get("templates", []): + template_id = str(tpl.get("id", "")) + if hashlib.sha1(template_id.encode("utf-8")).hexdigest()[:12] == key_or_id: + return template_id + + return str(key_or_id) + + +def pro_template_key_for_id(context: ContextTypes.DEFAULT_TYPE, template_id: str) -> str: + """Store a template id behind a short key that fits Telegram callback limits.""" + mapping = pro_template_map(context) + template_id = str(template_id) + for key, stored_id in mapping.items(): + if stored_id == template_id: + return str(key) + key = hashlib.sha1(template_id.encode("utf-8")).hexdigest()[:12] + mapping[key] = template_id + return key + + async def safe_edit_message( query, text: str, @@ -449,9 +486,11 @@ async def check_service_status(service: str) -> bool: async def get_telemt_version() -> str: """Get telemt version.""" - code, stdout, _ = await sh("telemt", "-v") - if code == 0: - return stdout.strip().split()[-1] if stdout else "unknown" + for command in ("telemt", "/usr/local/bin/telemt", "/usr/bin/telemt"): + for args in (("--version",), ("-V",)): + code, stdout, _ = await sh(command, *args, timeout=5) + if code == 0 and stdout.strip(): + return stdout.strip().split()[-1] return "unknown" @@ -1258,10 +1297,11 @@ async def cb_pro_category(update: Update, context: ContextTypes.DEFAULT_TYPE) -> buttons = [] for tpl in templates: + key = pro_template_key_for_id(context, tpl["id"]) buttons.append( [ InlineKeyboardButton( - f"🎨 {tpl['name']}", callback_data=f"pro_tpl_{tpl['id']}" + f"🎨 {tpl['name']}", callback_data=f"pro_tpl_{key}" ) ] ) @@ -1276,7 +1316,8 @@ async def cb_pro_template(update: Update, context: ContextTypes.DEFAULT_TYPE) -> """Show template preview and confirm.""" query = update.callback_query data = query.data - tpl_id = data.removeprefix("pro_tpl_") + tpl_key = data.removeprefix("pro_tpl_") + tpl_id = resolve_pro_template_id(context, tpl_key) await query.answer() @@ -1315,7 +1356,7 @@ async def cb_pro_template(update: Update, context: ContextTypes.DEFAULT_TYPE) -> buttons = [ [ InlineKeyboardButton( - "✅ Install", callback_data=f"pro_confirm_{tpl_id}" + "✅ Install", callback_data=f"pro_confirm_{pro_template_key_for_id(context, tpl_id)}" ) ], [InlineKeyboardButton("« Back", callback_data="install_mode_pro")], @@ -1341,7 +1382,8 @@ async def cb_pro_confirm(update: Update, context: ContextTypes.DEFAULT_TYPE) -> """ query = update.callback_query data = query.data - tpl_id = data.removeprefix("pro_confirm_") + tpl_key = data.removeprefix("pro_confirm_") + tpl_id = resolve_pro_template_id(context, tpl_key) await query.answer() @@ -2570,9 +2612,7 @@ async def cb_menu_update(update: Update, context: ContextTypes.DEFAULT_TYPE) -> await safe_edit_message(query,"⏳ Checking for telemt updates...") - # Get current version - cur_code, cur_out, _ = await sh("telemt", "--version") - current = cur_out.strip() if cur_code == 0 else "unknown" + current = await get_telemt_version() # Check latest release from GitHub code, stdout, stderr = await sh( @@ -2586,7 +2626,7 @@ async def cb_menu_update(update: Update, context: ContextTypes.DEFAULT_TYPE) -> try: release = json.loads(stdout) latest = release.get("tag_name", "unknown") - if latest == current: + if latest.lstrip("v") == current.lstrip("v"): text = f"✅ telemt is already up to date ({html.escape(current)})" else: text = ( diff --git a/tests/test_bot_features.py b/tests/test_bot_features.py new file mode 100644 index 0000000..1d2e21b --- /dev/null +++ b/tests/test_bot_features.py @@ -0,0 +1,84 @@ +import json +import re +import unittest +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] +BOT_PATH = ROOT / "gotelegram-bot" / "bot.py" +CATALOG_PATH = ROOT / "templates_catalog.json" + + +class BotFeatureTests(unittest.TestCase): + def test_catalog_contains_template_ids_that_break_raw_callback_data(self): + catalog = json.loads(CATALOG_PATH.read_text(encoding="utf-8")) + raw_lengths = [ + len(f"pro_tpl_{tpl['id']}".encode("utf-8")) + for cat in catalog.get("categories", []) + for tpl in cat.get("templates", []) + ] + + self.assertTrue(any(length > 64 for length in raw_lengths)) + + def test_bot_uses_short_template_callback_keys(self): + source = BOT_PATH.read_text(encoding="utf-8") + + self.assertNotIn("callback_data=f\"pro_tpl_{tpl['id']}\"", source) + self.assertNotIn('callback_data=f"pro_confirm_{tpl_id}"', source) + self.assertIn("pro_template_map", source) + self.assertIn("resolve_pro_template_id", source) + + def test_template_callbacks_are_restart_safe_hashes(self): + source = BOT_PATH.read_text(encoding="utf-8") + category_body = re.search( + r"async def cb_pro_category\(.*?(?=\n\n(?:async )?def |\n\n#)", + source, + flags=re.S, + ) + resolve_body = re.search( + r"def resolve_pro_template_id\(.*?(?=\n\n(?:async )?def |\n\n#)", + source, + flags=re.S, + ) + self.assertIsNotNone(category_body) + self.assertIsNotNone(resolve_body) + + self.assertIn('pro_template_key_for_id(context, tpl["id"])', category_body.group(0)) + self.assertNotIn("enumerate(templates)", category_body.group(0)) + self.assertNotIn("mapping.clear()", category_body.group(0)) + self.assertIn("load_json(TEMPLATES_CATALOG)", resolve_body.group(0)) + self.assertIn("hashlib.sha1", resolve_body.group(0)) + + def test_telemt_version_checks_systemd_path_fallbacks(self): + source = BOT_PATH.read_text(encoding="utf-8") + version_body = re.search( + r"async def get_telemt_version\(\).*?(?=\n\n(?:async )?def |\n\n#)", + source, + flags=re.S, + ) + self.assertIsNotNone(version_body) + body = version_body.group(0) + + self.assertIn('"--version"', body) + self.assertIn('"-V"', body) + self.assertIn('"/usr/local/bin/telemt"', body) + self.assertIn("for command in", body) + self.assertIn("for args in", body) + self.assertNotIn('"-v"', body) + + def test_telemt_update_menu_reuses_version_helper(self): + source = BOT_PATH.read_text(encoding="utf-8") + update_body = re.search( + r"async def cb_menu_update\(.*?(?=\n\n(?:async )?def |\n\n#)", + source, + flags=re.S, + ) + self.assertIsNotNone(update_body) + body = update_body.group(0) + + self.assertIn("await get_telemt_version()", body) + self.assertNotIn('sh("telemt", "--version")', body) + + +if __name__ == "__main__": + unittest.main()