security: gh security recs (#3368)

* change ALLOW_SIGNUP to default to false

* add 1.4.0 tag for OIDC docs

* new notes on security, in line with the security/policy review

* safer transport for external requests

* fix linter errors

* docs: Tidy up wording/formatting

* fix request errors

* whoops

* fix implementation with std lib

* format

* Remove check on netloc_parts; the netloc only includes the part of the URL after any @

---------

Co-authored-by: boc-the-git <3479092+boc-the-git@users.noreply.github.com>
Co-authored-by: Brendan <b.oconnell14@gmail.com>
Hayden authored 2024-04-02 10:04:42 -05:00, committed by GitHub
parent 737a370874
commit 2a3463b746
11 changed files with 180 additions and 54 deletions

View File

@@ -47,7 +47,7 @@ class AppSettings(BaseSettings):
    GIT_COMMIT_HASH: str = "unknown"

-    ALLOW_SIGNUP: bool = True
+    ALLOW_SIGNUP: bool = False

    # ===============================================
    # Security Configuration
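Because AppSettings is a pydantic settings class, this change only closes signups out of the box; deployments can still opt back in through the environment. A standalone sketch of that behaviour, assuming the usual pydantic behaviour of reading field values from the environment (DemoSettings is illustrative, not a Mealie class; on pydantic v1 the import would be `from pydantic import BaseSettings`):

import os

from pydantic_settings import BaseSettings  # assumption: pydantic v2 settings package


class DemoSettings(BaseSettings):
    # Mirrors the new default above: signups are closed unless explicitly enabled.
    ALLOW_SIGNUP: bool = False


print(DemoSettings().ALLOW_SIGNUP)  # False with nothing set (the new default)

os.environ["ALLOW_SIGNUP"] = "true"
print(DemoSettings().ALLOW_SIGNUP)  # True, opted back in via the environment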

View File

@@ -0,0 +1,7 @@
from .transport import AsyncSafeTransport, ForcedTimeoutException, InvalidDomainError

__all__ = [
    "AsyncSafeTransport",
    "ForcedTimeoutException",
    "InvalidDomainError",
]

View File

@@ -0,0 +1,78 @@
import ipaddress
import logging
import socket

import httpx


class ForcedTimeoutException(Exception):
    """
    Raised when a request takes longer than the timeout value.
    """

    ...


class InvalidDomainError(Exception):
    """
    Raised when a request is made to a local IP address.
    """

    ...


class AsyncSafeTransport(httpx.AsyncBaseTransport):
    """
    A wrapper around the httpx transport class that enforces a timeout value
    and that the request is not made to a local IP address.
    """

    timeout: int = 15

    def __init__(self, log: logging.Logger | None = None, **kwargs):
        self.timeout = kwargs.pop("timeout", self.timeout)
        self._wrapper = httpx.AsyncHTTPTransport(**kwargs)
        self._log = log

    async def handle_async_request(self, request):
        # override timeout value for _all_ requests
        request.extensions["timeout"] = httpx.Timeout(self.timeout, pool=self.timeout).as_dict()

        # validate the request is not attempting to connect to a local IP
        # This is a security measure to prevent SSRF attacks
        ip: ipaddress.IPv4Address | ipaddress.IPv6Address | None = None

        netloc = request.url.netloc.decode()
        if ":" in netloc:  # Either an IP, or a hostname:port combo
            netloc_parts = netloc.split(":")
            netloc = netloc_parts[0]

        try:
            ip = ipaddress.ip_address(netloc)
        except ValueError:
            if self._log:
                self._log.debug(f"failed to parse ip for {netloc=} falling back to domain resolution")
            pass

        # Request is a domain or a hostname.
        if not ip:
            if self._log:
                self._log.debug(f"resolving IP for domain: {netloc}")

            ip_str = socket.gethostbyname(netloc)
            ip = ipaddress.ip_address(ip_str)

            if self._log:
                self._log.debug(f"resolved IP for domain: {netloc} -> {ip}")

        if ip.is_private:
            if self._log:
                self._log.warning(f"invalid request on local resource: {request.url} -> {ip}")
            raise InvalidDomainError(f"invalid request on local resource: {request.url} -> {ip}")

        return await self._wrapper.handle_async_request(request)

    async def aclose(self):
        await self._wrapper.aclose()
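Taken together with the package __init__ above, the transport is meant to be dropped into an httpx.AsyncClient. A minimal usage sketch (the URLs are illustrative, not from the commit):

import asyncio

import httpx

from mealie.pkgs.safehttp import AsyncSafeTransport, InvalidDomainError


async def main() -> None:
    # Every request made through this client gets the 15s timeout and the
    # private/loopback IP check from AsyncSafeTransport before it goes out.
    async with httpx.AsyncClient(transport=AsyncSafeTransport()) as client:
        resp = await client.head("https://example.com/image.jpg")
        print(resp.status_code)

        try:
            await client.get("http://127.0.0.1:8080/admin")
        except InvalidDomainError as exc:
            print(f"blocked: {exc}")


asyncio.run(main())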

View File

@@ -5,7 +5,8 @@ from pathlib import Path
 from httpx import AsyncClient, Response
 from pydantic import UUID4

-from mealie.pkgs import img
+from mealie.pkgs import img, safehttp
+from mealie.pkgs.safehttp.transport import AsyncSafeTransport
 from mealie.schema.recipe.recipe import Recipe
 from mealie.services._base_service import BaseService
@@ -29,12 +30,14 @@ async def largest_content_len(urls: list[str]) -> tuple[str, int]:
    largest_url = ""
    largest_len = 0

+    max_concurrency = 10

    async def do(client: AsyncClient, url: str) -> Response:
        return await client.head(url, headers={"User-Agent": _FIREFOX_UA})

-    async with AsyncClient() as client:
+    async with AsyncClient(transport=safehttp.AsyncSafeTransport()) as client:
        tasks = [do(client, url) for url in urls]
-        responses: list[Response] = await gather_with_concurrency(10, *tasks, ignore_exceptions=True)
+        responses: list[Response] = await gather_with_concurrency(max_concurrency, *tasks, ignore_exceptions=True)

        for response in responses:
            len_int = int(response.headers.get("Content-Length", 0))
            if len_int > largest_len:
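gather_with_concurrency itself is not part of this diff. Assuming it is the usual semaphore-based helper, a minimal sketch of the pattern (the signature mirrors the call above; the body is an assumption, not Mealie's implementation):

import asyncio


async def gather_with_concurrency(n: int, *coros, ignore_exceptions: bool = False):
    # Limit how many of the given coroutines run at once using a semaphore.
    semaphore = asyncio.Semaphore(n)

    async def sem_coro(coro):
        async with semaphore:
            return await coro

    results = await asyncio.gather(*(sem_coro(c) for c in coros), return_exceptions=ignore_exceptions)
    if ignore_exceptions:
        # Drop failed requests so callers only see successful responses.
        results = [r for r in results if not isinstance(r, BaseException)]
    return results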
@@ -101,42 +104,29 @@ class RecipeDataService(BaseService):
        return image_path

-    @staticmethod
-    def _validate_image_url(url: str) -> bool:
-        # sourcery skip: invert-any-all, use-any
-        """
-        Validates that the URL is of an allowed source and restricts certain sources to prevent
-        malicious images from being downloaded.
-        """
-        invalid_domains = {"127.0.0.1", "localhost"}
-        for domain in invalid_domains:
-            if domain in url:
-                return False
-        return True

-    async def scrape_image(self, image_url) -> None:
+    async def scrape_image(self, image_url: str | dict[str, str] | list[str]) -> None:
        self.logger.info(f"Image URL: {image_url}")

-        if not self._validate_image_url(image_url):
-            self.logger.error(f"Invalid image URL: {image_url}")
-            raise InvalidDomainError(f"Invalid domain: {image_url}")
+        image_url_str = ""

        if isinstance(image_url, str):  # Handles String Types
-            pass
+            image_url_str = image_url

        elif isinstance(image_url, list):  # Handles List Types
            # Multiple images have been defined in the schema - usually different resolutions
            # Typically would be in smallest->biggest order, but can't be certain so test each.
            # 'Google will pick the best image to display in Search results based on the aspect ratio and resolution.'
-            image_url, _ = await largest_content_len(image_url)
+            image_url_str, _ = await largest_content_len(image_url)

        elif isinstance(image_url, dict):  # Handles Dictionary Types
            for key in image_url:
                if key == "url":
-                    image_url = image_url.get("url")
+                    image_url_str = image_url.get("url", "")

-        ext = image_url.split(".")[-1]
+        if not image_url_str:
+            raise ValueError(f"image url could not be parsed from input: {image_url}")
+
+        ext = image_url_str.split(".")[-1]

        if ext not in img.IMAGE_EXTENSIONS:
            ext = "jpg"  # Guess the extension
@@ -144,9 +134,9 @@ class RecipeDataService(BaseService):
        file_name = f"{str(self.recipe_id)}.{ext}"
        file_path = Recipe.directory_from_id(self.recipe_id).joinpath("images", file_name)

-        async with AsyncClient() as client:
+        async with AsyncClient(transport=AsyncSafeTransport()) as client:
            try:
-                r = await client.get(image_url, headers={"User-Agent": _FIREFOX_UA})
+                r = await client.get(image_url_str, headers={"User-Agent": _FIREFOX_UA})
            except Exception:
                self.logger.exception("Fatal Image Request Exception")
                return None
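For reference, the branch handling above covers the shapes a scraped recipe's image field can take (plain URL, list of resolutions, schema.org-style dict). A hypothetical, self-contained restatement of that normalization, purely for illustration (pick_largest stands in for largest_content_len above):

from typing import Awaitable, Callable


async def normalize_image_url(
    image_url: str | dict[str, str] | list[str],
    pick_largest: Callable[[list[str]], Awaitable[tuple[str, int]]],
) -> str:
    # Collapse the three schema shapes into a single URL string, or "" if none found.
    if isinstance(image_url, str):
        return image_url
    if isinstance(image_url, list):
        # Several candidate images; ask the helper which has the largest content length.
        url, _ = await pick_largest(image_url)
        return url
    if isinstance(image_url, dict):
        return image_url.get("url", "")
    return ""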

View File

@@ -43,7 +43,7 @@ async def create_from_url(url: str, translator: Translator) -> tuple[Recipe, Scr
    recipe_data_service = RecipeDataService(new_recipe.id)

-    await recipe_data_service.scrape_image(new_recipe.image)
+    await recipe_data_service.scrape_image(new_recipe.image)  # type: ignore

    if new_recipe.name is None:
        new_recipe.name = "Untitled"

View File

@@ -12,6 +12,7 @@ from w3lib.html import get_base_url
 from mealie.core.root_logger import get_logger
 from mealie.lang.providers import Translator
+from mealie.pkgs import safehttp
 from mealie.schema.recipe.recipe import Recipe, RecipeStep
 from mealie.services.scraper.scraped_extras import ScrapedExtras
@@ -31,7 +32,7 @@ async def safe_scrape_html(url: str) -> str:
    if the request takes longer than 15 seconds. This is used to mitigate
    DDOS attacks from users providing a url with arbitrary large content.
    """

-    async with AsyncClient() as client:
+    async with AsyncClient(transport=safehttp.AsyncSafeTransport()) as client:
        html_bytes = b""

        async with client.stream("GET", url, timeout=SCRAPER_TIMEOUT, headers={"User-Agent": _FIREFOX_UA}) as resp:
            start_time = time.time()
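The rest of safe_scrape_html is truncated in this view, but the docstring above describes the pattern: stream the body and bail out if it takes too long or grows too large. A hedged sketch of that kind of guard loop, not the file's actual implementation (the limits and the bounded_scrape_html name are assumptions; ForcedTimeoutException does exist in mealie.pkgs.safehttp, but its use here is illustrative):

import time

import httpx

from mealie.pkgs.safehttp import AsyncSafeTransport, ForcedTimeoutException

# Illustrative limits; the real module defines its own timeout and size cap.
_SCRAPE_TIMEOUT_SECONDS = 15
_MAX_HTML_BYTES = 5 * 1024 * 1024


async def bounded_scrape_html(url: str) -> str:
    html_bytes = b""

    async with httpx.AsyncClient(transport=AsyncSafeTransport()) as client:
        async with client.stream("GET", url, timeout=_SCRAPE_TIMEOUT_SECONDS) as resp:
            start_time = time.time()

            async for chunk in resp.aiter_bytes():
                html_bytes += chunk

                # Abort slow responses instead of waiting on the socket timeout alone.
                if time.time() - start_time > _SCRAPE_TIMEOUT_SECONDS:
                    raise ForcedTimeoutException(f"scrape of {url} exceeded {_SCRAPE_TIMEOUT_SECONDS}s")

                # Stop reading arbitrarily large documents.
                if len(html_bytes) > _MAX_HTML_BYTES:
                    break

            return html_bytes.decode(resp.encoding or "utf-8", errors="replace")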