Replace insecure random with secrets module in cryptographic contexts (#3248)
* Initial plan * Security fixes: Replace insecure random with secrets module and improve SSL context Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com> * Address code review feedback: fix POST method and add named constants Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com> * Improve documentation for random number generation constants Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com> * Update astrbot/core/utils/io.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/test_security_fixes.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/utils/io.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update astrbot/core/utils/io.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Fix: Handle path parameter in SSL fallback for download_image_by_url Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com> Co-authored-by: LIghtJUNction <lightjunction.me@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -10,7 +10,7 @@ import base64
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import random
|
import secrets
|
||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
import time
|
import time
|
||||||
@@ -139,6 +139,12 @@ class PKCS7Encoder:
|
|||||||
class Prpcrypt:
|
class Prpcrypt:
|
||||||
"""提供接收和推送给企业微信消息的加解密接口"""
|
"""提供接收和推送给企业微信消息的加解密接口"""
|
||||||
|
|
||||||
|
# 16位随机字符串的范围常量
|
||||||
|
# randbelow(RANDOM_RANGE) 返回 [0, 8999999999999999](两端都包含,即包含0和8999999999999999)
|
||||||
|
# 加上 MIN_RANDOM_VALUE 后得到 [1000000000000000, 9999999999999999](两端都包含)即16位数字
|
||||||
|
MIN_RANDOM_VALUE = 1000000000000000 # 最小值: 1000000000000000 (16位)
|
||||||
|
RANDOM_RANGE = 9000000000000000 # 范围大小: 确保最大值为 9999999999999999 (16位)
|
||||||
|
|
||||||
def __init__(self, key):
|
def __init__(self, key):
|
||||||
# self.key = base64.b64decode(key+"=")
|
# self.key = base64.b64decode(key+"=")
|
||||||
self.key = key
|
self.key = key
|
||||||
@@ -207,7 +213,9 @@ class Prpcrypt:
|
|||||||
"""随机生成16位字符串
|
"""随机生成16位字符串
|
||||||
@return: 16位字符串
|
@return: 16位字符串
|
||||||
"""
|
"""
|
||||||
return str(random.randint(1000000000000000, 9999999999999999)).encode()
|
return str(
|
||||||
|
secrets.randbelow(self.RANDOM_RANGE) + self.MIN_RANDOM_VALUE
|
||||||
|
).encode()
|
||||||
|
|
||||||
|
|
||||||
class WXBizJsonMsgCrypt:
|
class WXBizJsonMsgCrypt:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
import random
|
import secrets
|
||||||
import string
|
import string
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -53,7 +53,7 @@ def generate_random_string(length: int = 10) -> str:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
letters = string.ascii_letters + string.digits
|
letters = string.ascii_letters + string.digits
|
||||||
return "".join(random.choice(letters) for _ in range(length))
|
return "".join(secrets.choice(letters) for _ in range(length))
|
||||||
|
|
||||||
|
|
||||||
def calculate_image_md5(image_data: bytes) -> str:
|
def calculate_image_md5(image_data: bytes) -> str:
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import random
|
|
||||||
import re
|
import re
|
||||||
|
import secrets
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -54,7 +54,9 @@ class OTTSProvider:
|
|||||||
async def _generate_signature(self) -> str:
|
async def _generate_signature(self) -> str:
|
||||||
await self._sync_time()
|
await self._sync_time()
|
||||||
timestamp = int(time.time()) + self.time_offset
|
timestamp = int(time.time()) + self.time_offset
|
||||||
nonce = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=10))
|
nonce = "".join(
|
||||||
|
secrets.choice("abcdefghijklmnopqrstuvwxyz0123456789") for _ in range(10)
|
||||||
|
)
|
||||||
path = re.sub(r"^https?://[^/]+", "", self.api_url) or "/"
|
path = re.sub(r"^https?://[^/]+", "", self.api_url) or "/"
|
||||||
return f"{timestamp}-{nonce}-0-{hashlib.md5(f'{path}-{timestamp}-{nonce}-0-{self.skey}'.encode()).hexdigest()}"
|
return f"{timestamp}-{nonce}-0-{hashlib.md5(f'{path}-{timestamp}-{nonce}-0-{self.skey}'.encode()).hexdigest()}"
|
||||||
|
|
||||||
|
|||||||
@@ -105,16 +105,31 @@ async def download_image_by_url(
|
|||||||
f.write(await resp.read())
|
f.write(await resp.read())
|
||||||
return path
|
return path
|
||||||
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
|
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
|
||||||
# 关闭SSL验证
|
# 关闭SSL验证(仅在证书验证失败时作为fallback)
|
||||||
|
logger.warning(
|
||||||
|
f"SSL certificate verification failed for {url}. "
|
||||||
|
"Disabling SSL verification (CERT_NONE) as a fallback. "
|
||||||
|
"This is insecure and exposes the application to man-in-the-middle attacks. "
|
||||||
|
"Please investigate and resolve certificate issues."
|
||||||
|
)
|
||||||
ssl_context = ssl.create_default_context()
|
ssl_context = ssl.create_default_context()
|
||||||
ssl_context.set_ciphers("DEFAULT")
|
ssl_context.check_hostname = False
|
||||||
|
ssl_context.verify_mode = ssl.CERT_NONE
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
if post:
|
if post:
|
||||||
async with session.get(url, ssl=ssl_context) as resp:
|
async with session.post(url, json=post_data, ssl=ssl_context) as resp:
|
||||||
|
if not path:
|
||||||
return save_temp_img(await resp.read())
|
return save_temp_img(await resp.read())
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
f.write(await resp.read())
|
||||||
|
return path
|
||||||
else:
|
else:
|
||||||
async with session.get(url, ssl=ssl_context) as resp:
|
async with session.get(url, ssl=ssl_context) as resp:
|
||||||
|
if not path:
|
||||||
return save_temp_img(await resp.read())
|
return save_temp_img(await resp.read())
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
f.write(await resp.read())
|
||||||
|
return path
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@@ -157,9 +172,19 @@ async def download_file(url: str, path: str, show_progress: bool = False):
|
|||||||
end="",
|
end="",
|
||||||
)
|
)
|
||||||
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
|
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
|
||||||
# 关闭SSL验证
|
# 关闭SSL验证(仅在证书验证失败时作为fallback)
|
||||||
|
logger.warning(
|
||||||
|
"SSL 证书验证失败,已关闭 SSL 验证(不安全,仅用于临时下载)。请检查目标服务器的证书配置。"
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
f"SSL certificate verification failed for {url}. "
|
||||||
|
"Falling back to unverified connection (CERT_NONE). "
|
||||||
|
"This is insecure and exposes the application to man-in-the-middle attacks. "
|
||||||
|
"Please investigate certificate issues with the remote server."
|
||||||
|
)
|
||||||
ssl_context = ssl.create_default_context()
|
ssl_context = ssl.create_default_context()
|
||||||
ssl_context.set_ciphers("DEFAULT")
|
ssl_context.check_hostname = False
|
||||||
|
ssl_context.verify_mode = ssl.CERT_NONE
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
|
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
|
||||||
total_size = int(resp.headers.get("content-length", 0))
|
total_size = int(resp.headers.get("content-length", 0))
|
||||||
|
|||||||
151
tests/test_security_fixes.py
Normal file
151
tests/test_security_fixes.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
"""Tests for security fixes - cryptographic random number generation and SSL context."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import ssl
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Add project root to sys.path
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def test_wecom_crypto_uses_secrets():
|
||||||
|
"""Test that WXBizJsonMsgCrypt uses secrets module instead of random."""
|
||||||
|
from astrbot.core.platform.sources.wecom_ai_bot.WXBizJsonMsgCrypt import Prpcrypt
|
||||||
|
|
||||||
|
# Create an instance and test that random string generation works
|
||||||
|
prpcrypt = Prpcrypt(b"test_key_32_bytes_long_value!")
|
||||||
|
|
||||||
|
# Generate multiple random strings and verify they are different and valid
|
||||||
|
random_strings = [prpcrypt.get_random_str() for _ in range(10)]
|
||||||
|
|
||||||
|
# All strings should be 16 bytes long
|
||||||
|
assert all(len(s) == 16 for s in random_strings)
|
||||||
|
|
||||||
|
# All strings should be different (extremely high probability with cryptographic random)
|
||||||
|
assert len(set(random_strings)) == 10
|
||||||
|
|
||||||
|
# All strings should be numeric when decoded
|
||||||
|
for s in random_strings:
|
||||||
|
decoded = s.decode()
|
||||||
|
assert decoded.isdigit()
|
||||||
|
assert 1000000000000000 <= int(decoded) <= 9999999999999999
|
||||||
|
|
||||||
|
|
||||||
|
def test_wecomai_utils_uses_secrets():
|
||||||
|
"""Test that wecomai_utils uses secrets module for random string generation."""
|
||||||
|
from astrbot.core.platform.sources.wecom_ai_bot.wecomai_utils import (
|
||||||
|
generate_random_string,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate multiple random strings and verify they are different
|
||||||
|
random_strings = [generate_random_string(10) for _ in range(20)]
|
||||||
|
|
||||||
|
# All strings should be 10 characters long
|
||||||
|
assert all(len(s) == 10 for s in random_strings)
|
||||||
|
|
||||||
|
# All strings should be alphanumeric
|
||||||
|
for s in random_strings:
|
||||||
|
assert s.isalnum()
|
||||||
|
|
||||||
|
# All strings should be different (extremely high probability with cryptographic random)
|
||||||
|
assert len(set(random_strings)) >= 19 # Allow for 1 collision in 20 (very unlikely)
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_tts_signature_uses_secrets():
|
||||||
|
"""Test that Azure TTS signature generation uses secrets module."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from astrbot.core.provider.sources.azure_tts_source import OTTSProvider
|
||||||
|
|
||||||
|
# Create a provider with test config
|
||||||
|
config = {
|
||||||
|
"OTTS_SKEY": "test_secret_key",
|
||||||
|
"OTTS_URL": "https://example.com/api/tts",
|
||||||
|
"OTTS_AUTH_TIME": "https://example.com/api/time",
|
||||||
|
}
|
||||||
|
|
||||||
|
async def test_nonce_generation():
|
||||||
|
async with OTTSProvider(config) as provider:
|
||||||
|
# Mock time sync to avoid actual API calls
|
||||||
|
provider.time_offset = 0
|
||||||
|
provider.last_sync_time = 9999999999
|
||||||
|
|
||||||
|
# Generate multiple signatures and extract nonces
|
||||||
|
signatures = []
|
||||||
|
for _ in range(10):
|
||||||
|
sig = await provider._generate_signature()
|
||||||
|
signatures.append(sig)
|
||||||
|
|
||||||
|
# Extract nonces (second field in signature format: timestamp-nonce-0-hash)
|
||||||
|
nonces = [sig.split("-")[1] for sig in signatures]
|
||||||
|
|
||||||
|
# All nonces should be 10 characters long
|
||||||
|
assert all(len(n) == 10 for n in nonces)
|
||||||
|
|
||||||
|
# All nonces should be alphanumeric (lowercase letters and digits)
|
||||||
|
for n in nonces:
|
||||||
|
assert all(c in "abcdefghijklmnopqrstuvwxyz0123456789" for c in n)
|
||||||
|
|
||||||
|
# All nonces should be different (cryptographic random ensures uniqueness)
|
||||||
|
assert len(set(nonces)) == 10
|
||||||
|
|
||||||
|
asyncio.run(test_nonce_generation())
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssl_context_fallback_explicit():
|
||||||
|
"""Test that SSL context fallback is properly configured."""
|
||||||
|
# This test verifies the SSL context configuration
|
||||||
|
# We can't easily test the full io.py functions without network calls,
|
||||||
|
# but we can verify that ssl.CERT_NONE and check_hostname=False are valid settings
|
||||||
|
|
||||||
|
# Create a context similar to what's used in io.py
|
||||||
|
ssl_context = ssl.create_default_context()
|
||||||
|
ssl_context.check_hostname = False
|
||||||
|
ssl_context.verify_mode = ssl.CERT_NONE
|
||||||
|
|
||||||
|
# Verify the settings are applied correctly
|
||||||
|
assert ssl_context.check_hostname is False
|
||||||
|
assert ssl_context.verify_mode == ssl.CERT_NONE
|
||||||
|
|
||||||
|
# This configuration should work but is intentionally insecure for fallback
|
||||||
|
# The actual code only uses this when certificate validation fails
|
||||||
|
|
||||||
|
|
||||||
|
def test_io_module_has_ssl_imports():
|
||||||
|
"""Verify that io.py properly imports ssl module."""
|
||||||
|
from astrbot.core.utils import io
|
||||||
|
|
||||||
|
# Check that ssl is available in the module
|
||||||
|
assert hasattr(io, "ssl")
|
||||||
|
|
||||||
|
# Check that CERT_NONE constant is accessible
|
||||||
|
assert hasattr(io.ssl, "CERT_NONE")
|
||||||
|
|
||||||
|
|
||||||
|
def test_secrets_module_randomness_quality():
|
||||||
|
"""Test that secrets module provides high-quality randomness."""
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
# Generate a large set of random numbers
|
||||||
|
random_numbers = [secrets.randbelow(100) for _ in range(1000)]
|
||||||
|
|
||||||
|
# Basic statistical test: should have good distribution
|
||||||
|
unique_values = len(set(random_numbers))
|
||||||
|
|
||||||
|
# With 1000 random numbers from 0-99, we should see most values at least once
|
||||||
|
# This is a very basic test - real cryptographic random should pass this easily
|
||||||
|
assert unique_values >= 60 # Should see at least 60 different values out of 100
|
||||||
|
|
||||||
|
# Test secrets.choice for string generation
|
||||||
|
chars = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||||
|
random_chars = [secrets.choice(chars) for _ in range(1000)]
|
||||||
|
|
||||||
|
# Should have good character distribution
|
||||||
|
unique_chars = len(set(random_chars))
|
||||||
|
assert unique_chars >= 20 # Should see at least 20 different characters
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
Reference in New Issue
Block a user