"""S3 compatible OBS(Object Storage Service) client implementation with AWS V4 signature authentication."""
__all__ = ["S3CompatibleObsClient"]
import hashlib
import hmac
import json
import logging
import os
import urllib.parse
import xml.etree.ElementTree as ElementTree
from datetime import datetime, timezone

import aiohttp
from pydantic import ConfigDict, PrivateAttr

from deepinsight.config.config import Config
from deepinsight.config.file_storage_config import ConfigS3
from deepinsight.utils.file_storage.base import BaseFileStorage, StorageError, StorageOp
logger = logging.getLogger(__name__)
class S3CompatibleObsClient(BaseFileStorage):
    """S3-compatible OBS (Object Storage Service) client signed with AWS Signature Version 4.

    Every operation is a single signed HTTP request against ``config.endpoint``.
    The client can be used as an async context manager (``async with``), or its
    methods can be called directly, in which case the HTTP session is created
    lazily on the first request. Failures are reported as ``StorageError``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Endpoint URL plus the access/secret key pair used to build and sign requests.
    config: ConfigS3
    # HTTP session; ``None`` until first use and again after ``__aexit__``.
    _session: aiohttp.ClientSession | None = PrivateAttr(None)
    # True until the "delete always tolerates missing files" warning has been logged once.
    _warn_delete_always_allow_unexist: bool = PrivateAttr(True)

    async def __aenter__(self):
        """Make sure an open HTTP session exists and return the client itself."""
        self._ensure_session()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP session and forget it so that later calls open a fresh one."""
        # Detach first: a closed session must never stay reachable through ``_session``,
        # otherwise the next request would be issued on a closed session.
        session, self._session = self._session, None
        if session is not None and not session.closed:
            await session.close()

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the current open session, creating one when missing or already closed.

        All sessions are created with ``trust_env=True`` so the client behaves the
        same whether it was entered with ``async with`` or used directly.
        """
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession(trust_env=True)
        return self._session

    @staticmethod
    def _xml_child_texts(xml_content: str, parent_tag: str, child_tag: str) -> list[str]:
        """Collect the non-empty text of the first ``child_tag`` under each ``parent_tag`` element.

        Tags are matched with the ``{*}`` wildcard, i.e. in any (or no) XML namespace,
        which covers the S3 namespace as well as services that use a different one.

        Raises:
            ElementTree.ParseError: if ``xml_content`` is not well-formed XML.
        """
        root = ElementTree.fromstring(xml_content)
        texts = []
        for parent in root.findall(f".//{{*}}{parent_tag}"):
            child = parent.find(f"{{*}}{child_tag}")
            if child is not None and child.text:
                texts.append(child.text)
        return texts

    @staticmethod
    def _parse_list_buckets_xml(xml_content: str) -> list[str]:
        """Extract the bucket names from a list-buckets XML response.

        Raises:
            StorageError: with reason ``OTHER`` when the XML cannot be parsed.
        """
        try:
            return S3CompatibleObsClient._xml_child_texts(xml_content, "Bucket", "Name")
        except ElementTree.ParseError as e:
            logger.error("Failed to parse XML response: %s", e)
            raise StorageError(StorageOp.LIST, "", reason=StorageError.Reason.OTHER) from e

    @staticmethod
    def _parse_list_objects_xml(xml_content: str, bucket: str, prefix: str | None = None) -> list[str]:
        """Extract the object keys from a list-objects XML response.

        Args:
            xml_content: Response body of the list-objects request.
            bucket: Bucket name, used only to describe a parse failure.
            prefix: When given, only keys starting with this prefix are returned.

        Raises:
            StorageError: with reason ``OTHER`` when the XML cannot be parsed.
        """
        try:
            keys = S3CompatibleObsClient._xml_child_texts(xml_content, "Contents", "Key")
        except ElementTree.ParseError as e:
            logger.error("Failed to parse XML response: %s", e)
            raise StorageError(StorageOp.LIST, bucket, prefix, reason=StorageError.Reason.OTHER) from e
        if prefix is None:
            return keys
        return [key for key in keys if key.startswith(prefix)]

    @classmethod
    def from_config(cls, config: Config) -> "S3CompatibleObsClient":
        """Build a client from ``config.file_storage.s3``, falling back to environment variables.

        When the config section is empty, ``S3_ENDPOINT``, ``S3_AK`` and ``S3_SK``
        are read from the environment instead.

        Raises:
            RuntimeError: if neither the config section nor all three variables are set.
        """
        s3_conf = config.file_storage.s3
        if not s3_conf:
            endpoint = os.getenv("S3_ENDPOINT")
            ak = os.getenv("S3_AK")
            sk = os.getenv("S3_SK")
            if not all([endpoint, ak, sk]):
                raise RuntimeError("For S3 storage, config file_storage.s3 or environ 'S3_ENDPOINT', 'S3_AK' and "
                                   "'S3_SK' is necessary.")
            s3_conf = ConfigS3(endpoint=endpoint, ak=ak, sk=sk)
        return cls(config=s3_conf)

    async def bucket_create(self, bucket: str, *, exist_ok: bool = False) -> bool:
        """Create a new bucket.

        The bucket is probed with ``HEAD`` first and only created with ``PUT``
        when the probe answers 404.

        Args:
            bucket: Name of the bucket to create.
            exist_ok: When true, an already existing bucket is not an error.

        Returns:
            ``True`` if the bucket was created, ``False`` if it already existed
            and ``exist_ok`` is set.

        Raises:
            StorageError: ``ALREADY_EXISTS``, ``NAME_ILLEGAL`` (probe answered 400),
                ``PERMISSION`` (403), ``NETWORK`` or ``OTHER``.
        """
        url = self._request_url(bucket)

        try:
            status, _, _ = await self._make_request("HEAD", url)
        except aiohttp.ClientError as e:
            raise StorageError(StorageOp.CREATE, bucket, reason=StorageError.Reason.NETWORK) from e
        if status == 200:
            if exist_ok:
                return False
            raise StorageError(StorageOp.CREATE, bucket, reason=StorageError.Reason.ALREADY_EXISTS)
        if status == 400:
            raise StorageError(StorageOp.CREATE, bucket, reason=StorageError.Reason.NAME_ILLEGAL)
        if status == 403:
            # Same 403 -> PERMISSION mapping as every other operation of this client.
            raise StorageError(StorageOp.CREATE, bucket, reason=StorageError.Reason.PERMISSION)
        if status != 404:
            raise StorageError(StorageOp.CREATE, bucket, reason=StorageError.Reason.OTHER)

        try:
            status, content, _ = await self._make_request("PUT", url)
        except aiohttp.ClientError as e:
            logger.error("Network error creating bucket: %s", e)
            raise StorageError(StorageOp.CREATE, bucket, reason=StorageError.Reason.NETWORK) from e
        if status not in (200, 201):
            error_text = content.decode("utf-8", errors="ignore")
            logger.error("Failed to create bucket %s: %s", bucket, error_text)
            if status == 403:
                raise StorageError(StorageOp.CREATE, bucket, reason=StorageError.Reason.PERMISSION)
            raise StorageError(StorageOp.CREATE, bucket, reason=StorageError.Reason.OTHER)
        return True

    async def list_buckets(self) -> list[str]:
        """Return the names of all buckets reported by the service.

        Raises:
            StorageError: ``PERMISSION`` (403), ``NETWORK`` or ``OTHER``.
        """
        try:
            status, content, _ = await self._make_request("GET", self._request_url())
        except aiohttp.ClientError as e:
            logger.error("Network error listing buckets: %s", e)
            raise StorageError(StorageOp.LIST, '', reason=StorageError.Reason.NETWORK) from e
        if status != 200:
            error_text = content.decode("utf-8", errors="ignore")
            logger.error("Failed to list buckets: %s", error_text)
            if status == 403:
                raise StorageError(StorageOp.LIST, '', reason=StorageError.Reason.PERMISSION)
            raise StorageError(StorageOp.LIST, '', reason=StorageError.Reason.OTHER)
        return self._parse_list_buckets_xml(content.decode("utf-8"))

    async def list_files(self, bucket: str, prefix: str | None = None) -> list[str]:
        """Return the object keys of ``bucket``, optionally restricted to ``prefix``.

        The request always sends ``delimiter=/`` and is issued exactly once.
        NOTE(review): no continuation/marker handling is done here, so a listing
        the service truncates would come back incomplete - confirm whether the
        expected bucket sizes make pagination necessary.

        Raises:
            StorageError: ``BUCKET_NOT_FOUND`` (404), ``PERMISSION`` (403),
                ``NETWORK`` or ``OTHER``.
        """
        params = {"delimiter": "/"}
        if prefix:
            params["prefix"] = prefix
        # Sorted keys and strictly percent-encoded values: the same string is used
        # as the canonical query string when the request is signed.
        query = "&".join(f"{name}={urllib.parse.quote(value, safe='')}"
                         for name, value in sorted(params.items(), key=lambda item: item[0]))
        url = f"{self._request_url(bucket)}?{query}"

        try:
            status, content, _ = await self._make_request("GET", url)
        except aiohttp.ClientError as e:
            logger.error("Network error listing files: %s", e)
            raise StorageError(StorageOp.LIST, bucket, prefix, reason=StorageError.Reason.NETWORK) from e
        if status == 404:
            raise StorageError(StorageOp.LIST, bucket, prefix, reason=StorageError.Reason.BUCKET_NOT_FOUND)
        if status != 200:
            error_text = content.decode("utf-8", errors="ignore")
            logger.error("Failed to list files in bucket %s: %s", bucket, error_text)
            if status == 403:
                raise StorageError(StorageOp.LIST, bucket, prefix, reason=StorageError.Reason.PERMISSION)
            raise StorageError(StorageOp.LIST, bucket, prefix, reason=StorageError.Reason.OTHER)
        return self._parse_list_objects_xml(content.decode("utf-8"), bucket, prefix)

    async def file_add(self, bucket: str, filename: str, content: bytes) -> None:
        """Upload ``content`` as object ``filename`` in ``bucket``.

        Raises:
            StorageError: ``BUCKET_NOT_FOUND`` (404), ``PERMISSION`` (403),
                ``NETWORK`` or ``OTHER``.
        """
        url = self._request_url(bucket, filename)
        try:
            status, response_body, _ = await self._make_request("PUT", url, data=content)
        except aiohttp.ClientError as e:
            logger.error("Network error adding file: %s", e)
            raise StorageError(StorageOp.CREATE, bucket, filename, reason=StorageError.Reason.NETWORK) from e
        if status == 404:
            raise StorageError(StorageOp.CREATE, bucket, filename, reason=StorageError.Reason.BUCKET_NOT_FOUND)
        if status not in (200, 201):
            error_text = response_body.decode("utf-8", errors="ignore")
            logger.error("Failed to add file %s to bucket %s: %s", filename, bucket, error_text)
            if status == 403:
                raise StorageError(StorageOp.CREATE, bucket, filename, reason=StorageError.Reason.PERMISSION)
            raise StorageError(StorageOp.CREATE, bucket, filename, reason=StorageError.Reason.OTHER)

    async def file_delete(self, bucket: str, filename: str, allow_not_exists: bool = True) -> None:
        """Delete object ``filename`` from ``bucket``.

        ``allow_not_exists=False`` is not enforced by this implementation; the
        first such call logs a one-time warning and the delete proceeds as usual.

        Raises:
            StorageError: ``PERMISSION`` (403), ``NETWORK`` or ``OTHER``.
        """
        url = self._request_url(bucket, filename)
        if not allow_not_exists and self._warn_delete_always_allow_unexist:
            self._warn_delete_always_allow_unexist = False
            logger.warning("Storage implementation %s always allows delete unexist files.", type(self).__name__)

        try:
            status, content, _ = await self._make_request("DELETE", url)
        except aiohttp.ClientError as e:
            logger.error("Network error deleting file: %s", e)
            raise StorageError(StorageOp.DELETE, bucket, filename, reason=StorageError.Reason.NETWORK) from e
        if status not in (200, 204):
            error_text = content.decode("utf-8", errors="ignore")
            logger.error("Failed to delete file %s from bucket %s: %s", filename, bucket, error_text)
            if status == 403:
                raise StorageError(StorageOp.DELETE, bucket, filename, reason=StorageError.Reason.PERMISSION)
            raise StorageError(StorageOp.DELETE, bucket, filename, reason=StorageError.Reason.OTHER)

    async def file_get(self, bucket: str, filename: str) -> bytes:
        """Download object ``filename`` from ``bucket`` and return its bytes.

        Raises:
            StorageError: ``FILE_NOT_FOUND`` (404), ``PERMISSION`` (403),
                ``BUCKET_NOT_FOUND`` (error body mentions ``InvalidBucketName``),
                ``NETWORK`` or ``OTHER``.
        """
        try:
            status, content, _ = await self._make_request("GET", self._request_url(bucket, filename))
        except aiohttp.ClientError as e:
            logger.error("Network error getting file: %s", e)
            raise StorageError(StorageOp.GET, bucket, filename, reason=StorageError.Reason.NETWORK) from e
        if status == 404:
            raise StorageError(StorageOp.GET, bucket, filename, reason=StorageError.Reason.FILE_NOT_FOUND)
        if status == 403:
            raise StorageError(StorageOp.GET, bucket, filename, reason=StorageError.Reason.PERMISSION)
        if status != 200:
            error_text = content.decode("utf-8", errors="ignore")
            if "InvalidBucketName" in error_text:
                raise StorageError(StorageOp.GET, bucket, filename, reason=StorageError.Reason.BUCKET_NOT_FOUND)
            logger.error("Failed to get file %s from bucket %s: %s", filename, bucket, error_text)
            raise StorageError(StorageOp.GET, bucket, filename, reason=StorageError.Reason.OTHER)
        return content

    async def bucket_allow_anonymous_get(self, bucket: str) -> None:
        """Install a bucket policy that allows ``s3:GetObject`` on every object for any principal.

        Raises:
            StorageError: always with operation ``CONFIG``; reason is
                ``BUCKET_NOT_FOUND`` (404), ``PERMISSION`` (403), ``NETWORK`` or ``OTHER``.
        """
        url = self._request_url(bucket) + "?policy="
        policy = {
            "Version": "2012-10-17",
            "Statement": [
                {
                    "Effect": "Allow",
                    "Principal": "*",
                    "Action": "s3:GetObject",
                    "Resource": f"arn:aws:s3:::{bucket}/*",
                }
            ],
        }
        try:
            status, content, _ = await self._make_request(
                "PUT", url, headers={"Content-Type": "application/json"},
                data=json.dumps(policy).encode("utf8"),
            )
        except aiohttp.ClientError as e:
            logger.error("Network error setting bucket policy to allow anonymous get: %s", e)
            raise StorageError(StorageOp.CONFIG, bucket, reason=StorageError.Reason.NETWORK) from e
        # The request targets the bucket itself, so a 404 means the bucket is missing,
        # and every failure is reported under the same CONFIG operation.
        if status == 404:
            raise StorageError(StorageOp.CONFIG, bucket, reason=StorageError.Reason.BUCKET_NOT_FOUND)
        if status == 403:
            raise StorageError(StorageOp.CONFIG, bucket, reason=StorageError.Reason.PERMISSION)
        if status not in (200, 204):
            error_text = content.decode("utf-8", errors="ignore")
            logger.error("Failed to set bucket %s policy to allow anonymous get: %s", bucket, error_text)
            raise StorageError(StorageOp.CONFIG, bucket, reason=StorageError.Reason.OTHER)

    def _get_aws_v4_signature(self, method: str, path: str, headers: dict, payload: bytes = b'',
                              query: str = '') -> dict:
        """Generate the AWS Signature Version 4 headers for one request.

        Args:
            method: HTTP method of the request.
            path: URL path, already percent-encoded; used as the canonical URI.
            headers: Caller-supplied headers; they are merged over the generated
                ones and are not part of the signature.
            payload: Request body whose SHA-256 digest is signed.
            query: Query string, already in canonical (sorted, percent-encoded) form.

        Returns:
            A new dict with ``Authorization``, ``x-amz-date`` and
            ``x-amz-content-sha256`` plus everything from ``headers``.
        """
        host = urllib.parse.urlparse(self.config.endpoint).netloc
        service = "s3"
        region = "us-east-1"
        algorithm = "AWS4-HMAC-SHA256"

        now = datetime.now(timezone.utc)
        amz_date = now.strftime("%Y%m%dT%H%M%SZ")
        date_stamp = now.strftime("%Y%m%d")
        payload_hash = hashlib.sha256(payload).hexdigest()

        # Every x-amz-* header that is sent must also be signed; the canonical
        # header list is sorted by lower-case header name.
        canonical_headers = (f"host:{host}\n"
                             f"x-amz-content-sha256:{payload_hash}\n"
                             f"x-amz-date:{amz_date}\n")
        signed_headers = "host;x-amz-content-sha256;x-amz-date"
        canonical_request = (f"{method}\n{path}\n{query}\n{canonical_headers}\n"
                             f"{signed_headers}\n{payload_hash}")

        credential_scope = f"{date_stamp}/{region}/{service}/aws4_request"
        string_to_sign = (f"{algorithm}\n{amz_date}\n{credential_scope}\n"
                          f"{hashlib.sha256(canonical_request.encode()).hexdigest()}")
        signing_key = self._aws_v4_signature_key(date_stamp, region, service)
        signature = hmac.new(signing_key, string_to_sign.encode(), hashlib.sha256).hexdigest()

        authorization_header = (f"{algorithm} Credential={self.config.ak.get_secret_value()}/{credential_scope},"
                                f" SignedHeaders={signed_headers}, Signature={signature}")
        return {
            "Authorization": authorization_header,
            "x-amz-date": amz_date,
            "x-amz-content-sha256": payload_hash,
            **headers,
        }

    def _aws_v4_signature_key(self, date_stamp: str, region: str, service: str) -> bytes:
        """Derive the signing key by chaining HMAC-SHA256 over date, region, service and ``aws4_request``."""
        key = f"AWS4{self.config.sk.get_secret_value()}".encode()
        for part in (date_stamp, region, service, "aws4_request"):
            key = hmac.new(key, part.encode(), hashlib.sha256).digest()
        return key

    def _request_url(self, bucket: str | None = None, key: str | None = None) -> str:
        """Build the request URL for the service root, a bucket, or an object.

        ``bucket`` and ``key`` are percent-encoded with ``~`` and ``/`` kept as-is.
        Without a bucket the endpoint root (with trailing slash) is returned and
        ``key`` is ignored.
        """
        base_url = self.config.endpoint.rstrip("/")
        if not bucket:
            return f"{base_url}/"
        url = f"{base_url}/{urllib.parse.quote(bucket, safe='~/')}"
        if key:
            url += f"/{urllib.parse.quote(key, safe='~/')}"
        return url

    async def _make_request(self, method: str, url: str, headers: dict | None = None,
                            data: bytes | None = None) -> tuple[int, bytes, dict]:
        """Sign and send one HTTP request.

        Args:
            method: HTTP method.
            url: Full request URL; its path and query are what gets signed.
            headers: Extra request headers; the caller's dict is left untouched.
            data: Optional request body. A non-empty body gets a default
                ``Content-Type`` of ``application/octet-stream``.

        Returns:
            ``(status, body, response_headers)``.

        Raises:
            aiohttp.ClientError: propagated from the underlying session.
        """
        # Work on a copy so the default Content-Type never leaks into the caller's dict.
        request_headers = dict(headers) if headers else {}
        if data:
            request_headers.setdefault("Content-Type", "application/octet-stream")

        parsed_url = urllib.parse.urlparse(url)
        auth_headers = self._get_aws_v4_signature(method, parsed_url.path, request_headers,
                                                  data or b'', parsed_url.query)

        session = self._ensure_session()
        async with session.request(method, url, headers=auth_headers, data=data) as response:
            content = await response.read()
            return response.status, content, dict(response.headers)