diff --git a/Cargo.toml b/Cargo.toml index 1ac1a62e..1f70db08 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "openai-api-rs" -version = "6.0.13" +version = "7.0.0" edition = "2021" authors = ["Dongri Jin "] license = "MIT" @@ -17,7 +17,7 @@ default-tls = ["reqwest/default-tls", "tokio-tungstenite/native-tls"] [dependencies.reqwest] version = "0.12" default-features = false -features = ["charset", "http2", "json", "multipart", "socks"] +features = ["charset", "http2", "json", "multipart", "socks", "stream"] [dependencies.tokio] version = "1" diff --git a/README.md b/README.md index 58996582..13fe3f67 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Cargo.toml ```toml [dependencies] -openai-api-rs = "6.0.13" +openai-api-rs = "7.0.0" ``` ## Usage diff --git a/examples/chat_completion.rs b/examples/chat_completion.rs index 635add58..3556de08 100644 --- a/examples/chat_completion.rs +++ b/examples/chat_completion.rs @@ -1,5 +1,6 @@ use openai_api_rs::v1::api::OpenAIClient; -use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; +use openai_api_rs::v1::chat_completion::chat_completion::ChatCompletionRequest; +use openai_api_rs::v1::chat_completion::{self}; use openai_api_rs::v1::common::GPT4_O_MINI; use std::env; diff --git a/examples/chat_completion_stream.rs b/examples/chat_completion_stream.rs new file mode 100644 index 00000000..010e3f6a --- /dev/null +++ b/examples/chat_completion_stream.rs @@ -0,0 +1,44 @@ +use futures_util::StreamExt; +use openai_api_rs::v1::api::OpenAIClient; +use openai_api_rs::v1::chat_completion::chat_completion_stream::{ + ChatCompletionStreamRequest, ChatCompletionStreamResponse, +}; +use openai_api_rs::v1::chat_completion::{self}; +use openai_api_rs::v1::common::GPT4_O_MINI; +use std::env; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); + let mut client = OpenAIClient::builder().with_api_key(api_key).build()?; + + let req = ChatCompletionStreamRequest::new( + GPT4_O_MINI.to_string(), + vec![chat_completion::ChatCompletionMessage { + role: chat_completion::MessageRole::user, + content: chat_completion::Content::Text(String::from("What is bitcoin?")), + name: None, + tool_calls: None, + tool_call_id: None, + }], + ); + + let mut result = client.chat_completion_stream(req).await?; + while let Some(response) = result.next().await { + match response.clone() { + ChatCompletionStreamResponse::ToolCall(toolcalls) => { + println!("Tool Call: {:?}", toolcalls); + } + ChatCompletionStreamResponse::Content(content) => { + println!("Content: {:?}", content); + } + ChatCompletionStreamResponse::Done => { + println!("Done"); + } + } + } + + Ok(()) +} + +// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example chat_completion_stream diff --git a/examples/function_call.rs b/examples/function_call.rs index 3935599e..b18769f6 100644 --- a/examples/function_call.rs +++ b/examples/function_call.rs @@ -1,5 +1,10 @@ use openai_api_rs::v1::api::OpenAIClient; -use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; +use openai_api_rs::v1::chat_completion::{ + chat_completion::ChatCompletionRequest, ChatCompletionMessage, +}; +use openai_api_rs::v1::chat_completion::{ + Content, FinishReason, MessageRole, Tool, ToolChoiceType, ToolType, +}; use openai_api_rs::v1::common::GPT4_O; use openai_api_rs::v1::types; use serde::{Deserialize, Serialize}; @@ -32,16 +37,16 @@ async fn main() -> Result<(), Box> { let req = ChatCompletionRequest::new( GPT4_O.to_string(), - vec![chat_completion::ChatCompletionMessage { - role: chat_completion::MessageRole::user, - content: chat_completion::Content::Text(String::from("What is the price of Ethereum?")), + vec![ChatCompletionMessage { + role: MessageRole::user, + content: Content::Text(String::from("What is the price of Ethereum?")), name: None, tool_calls: None, tool_call_id: None, }], ) - .tools(vec![chat_completion::Tool { - r#type: chat_completion::ToolType::Function, + .tools(vec![Tool { + r#type: ToolType::Function, function: types::Function { name: String::from("get_coin_price"), description: Some(String::from("Get the price of a cryptocurrency")), @@ -52,7 +57,7 @@ async fn main() -> Result<(), Box> { }, }, }]) - .tool_choice(chat_completion::ToolChoiceType::Auto); + .tool_choice(ToolChoiceType::Auto); // debug request json // let serialized = serde_json::to_string(&req).unwrap(); @@ -65,14 +70,14 @@ async fn main() -> Result<(), Box> { println!("No finish_reason"); println!("{:?}", result.choices[0].message.content); } - Some(chat_completion::FinishReason::stop) => { + Some(FinishReason::stop) => { println!("Stop"); println!("{:?}", result.choices[0].message.content); } - Some(chat_completion::FinishReason::length) => { + Some(FinishReason::length) => { println!("Length"); } - Some(chat_completion::FinishReason::tool_calls) => { + Some(FinishReason::tool_calls) => { println!("ToolCalls"); #[derive(Deserialize, Serialize)] struct Currency { @@ -90,10 +95,10 @@ async fn main() -> Result<(), Box> { } } } - Some(chat_completion::FinishReason::content_filter) => { + Some(FinishReason::content_filter) => { println!("ContentFilter"); } - Some(chat_completion::FinishReason::null) => { + Some(FinishReason::null) => { println!("Null"); } } diff --git a/examples/function_call_role.rs b/examples/function_call_role.rs index 901d5d7f..a18d355c 100644 --- a/examples/function_call_role.rs +++ b/examples/function_call_role.rs @@ -1,5 +1,6 @@ use openai_api_rs::v1::api::OpenAIClient; -use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; +use openai_api_rs::v1::chat_completion::chat_completion::ChatCompletionRequest; +use openai_api_rs::v1::chat_completion::{self}; use openai_api_rs::v1::common::GPT4_O; use openai_api_rs::v1::types; use serde::{Deserialize, Serialize}; diff --git a/examples/openrouter.rs b/examples/openrouter.rs index 5295bf41..79b8ec01 100644 --- a/examples/openrouter.rs +++ b/examples/openrouter.rs @@ -1,5 +1,6 @@ use openai_api_rs::v1::api::OpenAIClient; -use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; +use openai_api_rs::v1::chat_completion::chat_completion::ChatCompletionRequest; +use openai_api_rs::v1::chat_completion::{self}; use openai_api_rs::v1::common::GPT4_O_MINI; use std::env; diff --git a/examples/openrouter_reasoning.rs b/examples/openrouter_reasoning.rs index 9bfac3dd..9dd3c125 100644 --- a/examples/openrouter_reasoning.rs +++ b/examples/openrouter_reasoning.rs @@ -1,7 +1,6 @@ use openai_api_rs::v1::api::OpenAIClient; -use openai_api_rs::v1::chat_completion::{ - self, ChatCompletionRequest, Reasoning, ReasoningEffort, ReasoningMode, -}; +use openai_api_rs::v1::chat_completion::chat_completion::ChatCompletionRequest; +use openai_api_rs::v1::chat_completion::{self, Reasoning, ReasoningEffort, ReasoningMode}; use std::env; #[tokio::main] diff --git a/examples/vision.rs b/examples/vision.rs index 7bad362b..67c9af5d 100644 --- a/examples/vision.rs +++ b/examples/vision.rs @@ -1,5 +1,6 @@ use openai_api_rs::v1::api::OpenAIClient; -use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; +use openai_api_rs::v1::chat_completion::chat_completion::ChatCompletionRequest; +use openai_api_rs::v1::chat_completion::{self}; use openai_api_rs::v1::common::GPT4_O; use std::env; diff --git a/src/v1/api.rs b/src/v1/api.rs index ce0b1c4e..8ff11652 100644 --- a/src/v1/api.rs +++ b/src/v1/api.rs @@ -7,7 +7,10 @@ use crate::v1::audio::{ AudioTranslationRequest, AudioTranslationResponse, }; use crate::v1::batch::{BatchResponse, CreateBatchRequest, ListBatchResponse}; -use crate::v1::chat_completion::{ChatCompletionRequest, ChatCompletionResponse}; +use crate::v1::chat_completion::chat_completion::{ChatCompletionRequest, ChatCompletionResponse}; +use crate::v1::chat_completion::chat_completion_stream::{ + ChatCompletionStream, ChatCompletionStreamRequest, ChatCompletionStreamResponse, +}; use crate::v1::common; use crate::v1::completion::{CompletionRequest, CompletionResponse}; use crate::v1::edit::{EditRequest, EditResponse}; @@ -39,11 +42,12 @@ use crate::v1::run::{ use crate::v1::thread::{CreateThreadRequest, ModifyThreadRequest, ThreadObject}; use bytes::Bytes; +use futures_util::Stream; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use reqwest::multipart::{Form, Part}; use reqwest::{Client, Method, Response}; use serde::Serialize; -use serde_json::Value; +use serde_json::{to_value, Value}; use url::Url; use std::error::Error; @@ -334,6 +338,40 @@ impl OpenAIClient { self.post("chat/completions", &req).await } + pub async fn chat_completion_stream( + &mut self, + req: ChatCompletionStreamRequest, + ) -> Result, APIError> { + let mut payload = to_value(&req).map_err(|err| APIError::CustomError { + message: format!("Failed to serialize request: {}", err), + })?; + + if let Some(obj) = payload.as_object_mut() { + obj.insert("stream".into(), Value::Bool(true)); + } + + let request = self.build_request(Method::POST, "chat/completions").await; + let request = request.json(&payload); + let response = request.send().await?; + + if response.status().is_success() { + Ok(ChatCompletionStream { + response: Box::pin(response.bytes_stream()), + buffer: String::new(), + first_chunk: true, + }) + } else { + let error_text = response + .text() + .await + .unwrap_or_else(|_| String::from("Unknown error")); + + Err(APIError::CustomError { + message: error_text, + }) + } + } + pub async fn audio_transcription( &mut self, req: AudioTranscriptionRequest, diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion/chat_completion.rs similarity index 51% rename from src/v1/chat_completion.rs rename to src/v1/chat_completion/chat_completion.rs index 09fdc0f0..2c56287d 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion/chat_completion.rs @@ -1,44 +1,13 @@ -use super::{common, types}; -use crate::impl_builder_methods; - -use serde::de::{self, MapAccess, SeqAccess, Visitor}; -use serde::ser::SerializeMap; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use crate::v1::chat_completion::{ChatCompletionChoice, Reasoning, Tool, ToolChoiceType}; +use crate::v1::common; +use crate::{ + impl_builder_methods, + v1::chat_completion::{serialize_tool_choice, ChatCompletionMessage}, +}; + +use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashMap; -use std::fmt; -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] -pub enum ToolChoiceType { - None, - Auto, - Required, - ToolChoice { tool: Tool }, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] -#[serde(rename_all = "lowercase")] -pub enum ReasoningEffort { - Low, - Medium, - High, -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -#[serde(untagged)] -pub enum ReasoningMode { - Effort { effort: ReasoningEffort }, - MaxTokens { max_tokens: i64 }, -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct Reasoning { - #[serde(flatten)] - pub mode: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub exclude: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub enabled: Option, -} #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ChatCompletionRequest { @@ -53,8 +22,6 @@ pub struct ChatCompletionRequest { #[serde(skip_serializing_if = "Option::is_none")] pub response_format: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub stream: Option, - #[serde(skip_serializing_if = "Option::is_none")] pub stop: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub max_tokens: Option, @@ -93,7 +60,6 @@ impl ChatCompletionRequest { messages, temperature: None, top_p: None, - stream: None, n: None, response_format: None, stop: None, @@ -118,7 +84,6 @@ impl_builder_methods!( top_p: f64, n: i64, response_format: Value, - stream: bool, stop: Vec, max_tokens: i64, presence_penalty: f64, @@ -133,154 +98,6 @@ impl_builder_methods!( transforms: Vec ); -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -#[allow(non_camel_case_types)] -pub enum MessageRole { - user, - system, - assistant, - function, - tool, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Content { - Text(String), - ImageUrl(Vec), -} - -impl serde::Serialize for Content { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - match *self { - Content::Text(ref text) => { - if text.is_empty() { - serializer.serialize_none() - } else { - serializer.serialize_str(text) - } - } - Content::ImageUrl(ref image_url) => image_url.serialize(serializer), - } - } -} - -impl<'de> Deserialize<'de> for Content { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct ContentVisitor; - - impl<'de> Visitor<'de> for ContentVisitor { - type Value = Content; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a valid content type") - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - Ok(Content::Text(value.to_string())) - } - - fn visit_seq(self, seq: A) -> Result - where - A: SeqAccess<'de>, - { - let image_urls: Vec = - Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?; - Ok(Content::ImageUrl(image_urls)) - } - - fn visit_map(self, map: M) -> Result - where - M: MapAccess<'de>, - { - let image_urls: Vec = - Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?; - Ok(Content::ImageUrl(image_urls)) - } - - fn visit_none(self) -> Result - where - E: de::Error, - { - Ok(Content::Text(String::new())) - } - - fn visit_unit(self) -> Result - where - E: de::Error, - { - Ok(Content::Text(String::new())) - } - } - - deserializer.deserialize_any(ContentVisitor) - } -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -#[allow(non_camel_case_types)] -pub enum ContentType { - text, - image_url, -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -#[allow(non_camel_case_types)] -pub struct ImageUrlType { - pub url: String, -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -#[allow(non_camel_case_types)] -pub struct ImageUrl { - pub r#type: ContentType, - #[serde(skip_serializing_if = "Option::is_none")] - pub text: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub image_url: Option, -} - -#[derive(Debug, Deserialize, Serialize, Clone)] -pub struct ChatCompletionMessage { - pub role: MessageRole, - pub content: Content, - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_call_id: Option, -} - -#[derive(Debug, Deserialize, Serialize, Clone)] -pub struct ChatCompletionMessageForResponse { - pub role: MessageRole, - #[serde(skip_serializing_if = "Option::is_none")] - pub content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub reasoning_content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, -} - -#[derive(Debug, Deserialize, Serialize)] -pub struct ChatCompletionChoice { - pub index: i64, - pub message: ChatCompletionMessageForResponse, - pub finish_reason: Option, - pub finish_details: Option, -} - #[derive(Debug, Deserialize, Serialize)] pub struct ChatCompletionResponse { pub id: Option, @@ -292,73 +109,10 @@ pub struct ChatCompletionResponse { pub system_fingerprint: Option, } -#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)] -#[allow(non_camel_case_types)] -pub enum FinishReason { - stop, - length, - content_filter, - tool_calls, - null, -} - -#[derive(Debug, Deserialize, Serialize)] -#[allow(non_camel_case_types)] -pub struct FinishDetails { - pub r#type: FinishReason, - pub stop: String, -} - -#[derive(Debug, Deserialize, Serialize, Clone)] -pub struct ToolCall { - pub id: String, - pub r#type: String, - pub function: ToolCallFunction, -} - -#[derive(Debug, Deserialize, Serialize, Clone)] -pub struct ToolCallFunction { - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub arguments: Option, -} - -fn serialize_tool_choice( - value: &Option, - serializer: S, -) -> Result -where - S: Serializer, -{ - match value { - Some(ToolChoiceType::None) => serializer.serialize_str("none"), - Some(ToolChoiceType::Auto) => serializer.serialize_str("auto"), - Some(ToolChoiceType::Required) => serializer.serialize_str("required"), - Some(ToolChoiceType::ToolChoice { tool }) => { - let mut map = serializer.serialize_map(Some(2))?; - map.serialize_entry("type", &tool.r#type)?; - map.serialize_entry("function", &tool.function)?; - map.end() - } - None => serializer.serialize_none(), - } -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -pub struct Tool { - pub r#type: ToolType, - pub function: types::Function, -} - -#[derive(Debug, Deserialize, Serialize, Copy, Clone, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub enum ToolType { - Function, -} - #[cfg(test)] mod tests { + use crate::v1::chat_completion::{ReasoningEffort, ReasoningMode}; + use super::*; use serde_json::json; diff --git a/src/v1/chat_completion/chat_completion_stream.rs b/src/v1/chat_completion/chat_completion_stream.rs new file mode 100644 index 00000000..f5b3283e --- /dev/null +++ b/src/v1/chat_completion/chat_completion_stream.rs @@ -0,0 +1,361 @@ +use crate::v1::chat_completion::{Reasoning, Tool, ToolCall, ToolChoiceType}; +use crate::{ + impl_builder_methods, + v1::chat_completion::{serialize_tool_choice, ChatCompletionMessage}, +}; + +use futures_util::Stream; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; +use std::pin::Pin; +use std::task::{Context, Poll}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ChatCompletionStreamRequest { + pub model: String, + pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(serialize_with = "serialize_tool_choice")] + pub tool_choice: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning: Option, + /// Optional list of transforms to apply to the chat completion request. + /// + /// Transforms allow modifying the request before it's sent to the API, + /// enabling features like prompt rewriting, content filtering, or other + /// preprocessing steps. When None, no transforms are applied. + #[serde(skip_serializing_if = "Option::is_none")] + pub transforms: Option>, +} + +impl ChatCompletionStreamRequest { + pub fn new(model: String, messages: Vec) -> Self { + Self { + model, + messages, + temperature: None, + top_p: None, + n: None, + response_format: None, + stop: None, + max_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + seed: None, + tools: None, + parallel_tool_calls: None, + tool_choice: None, + reasoning: None, + transforms: None, + } + } +} + +impl_builder_methods!( + ChatCompletionStreamRequest, + temperature: f64, + top_p: f64, + n: i64, + response_format: Value, + stop: Vec, + max_tokens: i64, + presence_penalty: f64, + frequency_penalty: f64, + logit_bias: HashMap, + user: String, + seed: i64, + tools: Vec, + parallel_tool_calls: bool, + tool_choice: ToolChoiceType, + reasoning: Reasoning, + transforms: Vec +); + +#[derive(Debug, Clone)] +pub enum ChatCompletionStreamResponse { + Content(String), + ToolCall(Vec), + Done, +} + +pub struct ChatCompletionStream> + Unpin> { + pub response: S, + pub buffer: String, + pub first_chunk: bool, +} + +impl ChatCompletionStream +where + S: Stream> + Unpin, +{ + fn find_event_delimiter(buffer: &str) -> Option<(usize, usize)> { + let carriage_idx = buffer.find("\r\n\r\n"); + let newline_idx = buffer.find("\n\n"); + + match (carriage_idx, newline_idx) { + (Some(r_idx), Some(n_idx)) => { + if r_idx <= n_idx { + Some((r_idx, 4)) + } else { + Some((n_idx, 2)) + } + } + (Some(r_idx), None) => Some((r_idx, 4)), + (None, Some(n_idx)) => Some((n_idx, 2)), + (None, None) => None, + } + } + + fn next_response_from_buffer(&mut self) -> Option { + while let Some((idx, delimiter_len)) = Self::find_event_delimiter(&self.buffer) { + let event = self.buffer[..idx].to_owned(); + self.buffer = self.buffer[idx + delimiter_len..].to_owned(); + + let mut data_payload = String::new(); + for line in event.lines() { + let trimmed_line = line.trim_end_matches('\r'); + if let Some(content) = trimmed_line + .strip_prefix("data: ") + .or_else(|| trimmed_line.strip_prefix("data:")) + { + if !content.is_empty() { + if !data_payload.is_empty() { + data_payload.push('\n'); + } + data_payload.push_str(content); + } + } + } + + if data_payload.is_empty() { + continue; + } + + if data_payload == "[DONE]" { + return Some(ChatCompletionStreamResponse::Done); + } + + match serde_json::from_str::(&data_payload) { + Ok(json) => { + if let Some(delta) = json + .get("choices") + .and_then(|choices| choices.get(0)) + .and_then(|choice| choice.get("delta")) + { + if let Some(tool_call_response) = delta + .get("tool_calls") + .and_then(|tool_calls| tool_calls.as_array()) + .map(|tool_calls_array| { + tool_calls_array + .iter() + .filter_map(|v| serde_json::from_value(v.clone()).ok()) + .collect::>() + }) + .filter(|tool_calls_vec| !tool_calls_vec.is_empty()) + .map(ChatCompletionStreamResponse::ToolCall) + { + return Some(tool_call_response); + } + + if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { + let output = content.replace("\\n", "\n"); + return Some(ChatCompletionStreamResponse::Content(output)); + } + } + } + Err(error) => { + eprintln!("Failed to parse SSE chunk as JSON: {}", error); + } + } + } + + None + } +} + +impl> + Unpin> Stream + for ChatCompletionStream +{ + type Item = ChatCompletionStreamResponse; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + if let Some(response) = self.next_response_from_buffer() { + return Poll::Ready(Some(response)); + } + + match Pin::new(&mut self.as_mut().response).poll_next(cx) { + Poll::Ready(Some(Ok(chunk))) => { + let chunk_str = String::from_utf8_lossy(&chunk).to_string(); + + if self.first_chunk { + self.first_chunk = false; + } + self.buffer.push_str(&chunk_str); + } + Poll::Ready(Some(Err(error))) => { + eprintln!("Error in stream: {:?}", error); + return Poll::Ready(None); + } + Poll::Ready(None) => { + return Poll::Ready(None); + } + Poll::Pending => { + return Poll::Pending; + } + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::v1::chat_completion::{ReasoningEffort, ReasoningMode}; + + use super::*; + use serde_json::json; + + #[test] + fn test_reasoning_effort_serialization() { + let reasoning = Reasoning { + mode: Some(ReasoningMode::Effort { + effort: ReasoningEffort::High, + }), + exclude: Some(false), + enabled: None, + }; + + let serialized = serde_json::to_value(&reasoning).unwrap(); + let expected = json!({ + "effort": "high", + "exclude": false + }); + + assert_eq!(serialized, expected); + } + + #[test] + fn test_reasoning_max_tokens_serialization() { + let reasoning = Reasoning { + mode: Some(ReasoningMode::MaxTokens { max_tokens: 2000 }), + exclude: None, + enabled: Some(true), + }; + + let serialized = serde_json::to_value(&reasoning).unwrap(); + let expected = json!({ + "max_tokens": 2000, + "enabled": true + }); + + assert_eq!(serialized, expected); + } + + #[test] + fn test_reasoning_deserialization() { + let json_str = r#"{"effort": "medium", "exclude": true}"#; + let reasoning: Reasoning = serde_json::from_str(json_str).unwrap(); + + match reasoning.mode { + Some(ReasoningMode::Effort { effort }) => { + assert_eq!(effort, ReasoningEffort::Medium); + } + _ => panic!("Expected effort mode"), + } + assert_eq!(reasoning.exclude, Some(true)); + } + + #[test] + fn test_chat_completion_request_with_reasoning() { + let mut req = ChatCompletionStreamRequest::new("gpt-4".to_string(), vec![]); + + req.reasoning = Some(Reasoning { + mode: Some(ReasoningMode::Effort { + effort: ReasoningEffort::Low, + }), + exclude: None, + enabled: None, + }); + + let serialized = serde_json::to_value(&req).unwrap(); + assert_eq!(serialized["reasoning"]["effort"], "low"); + } + + #[test] + fn test_transforms_none_serialization() { + let req = ChatCompletionStreamRequest::new("gpt-4".to_string(), vec![]); + let serialised = serde_json::to_value(&req).unwrap(); + // Verify that the transforms field is completely omitted from JSON output + assert!(!serialised.as_object().unwrap().contains_key("transforms")); + } + + #[test] + fn test_transforms_some_serialization() { + let mut req = ChatCompletionStreamRequest::new("gpt-4".to_string(), vec![]); + req.transforms = Some(vec!["transform1".to_string(), "transform2".to_string()]); + let serialised = serde_json::to_value(&req).unwrap(); + // Verify that the transforms field is included as a proper JSON array + assert_eq!( + serialised["transforms"], + serde_json::json!(["transform1", "transform2"]) + ); + } + + #[test] + fn test_transforms_some_deserialization() { + let json_str = + r#"{"model": "gpt-4", "messages": [], "transforms": ["transform1", "transform2"]}"#; + let req: ChatCompletionStreamRequest = serde_json::from_str(json_str).unwrap(); + // Verify that the transforms field is properly populated with Some(vec) + assert_eq!( + req.transforms, + Some(vec!["transform1".to_string(), "transform2".to_string()]) + ); + } + + #[test] + fn test_transforms_none_deserialization() { + let json_str = r#"{"model": "gpt-4", "messages": []}"#; + let req: ChatCompletionStreamRequest = serde_json::from_str(json_str).unwrap(); + // Verify that the transforms field is properly set to None when absent + assert_eq!(req.transforms, None); + } + + #[test] + fn test_transforms_builder_method() { + let transforms = vec!["transform1".to_string(), "transform2".to_string()]; + let req = ChatCompletionStreamRequest::new("gpt-4".to_string(), vec![]) + .transforms(transforms.clone()); + // Verify that the transforms field is properly set through the builder method + assert_eq!(req.transforms, Some(transforms)); + } +} diff --git a/src/v1/chat_completion/mod.rs b/src/v1/chat_completion/mod.rs new file mode 100644 index 00000000..757e5052 --- /dev/null +++ b/src/v1/chat_completion/mod.rs @@ -0,0 +1,255 @@ +use crate::v1::types; +use serde::de::{self, MapAccess, SeqAccess, Visitor}; +use serde::ser::SerializeMap; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::fmt; + +#[allow(clippy::module_inception)] +pub mod chat_completion; +pub mod chat_completion_stream; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] +pub enum ToolChoiceType { + None, + Auto, + Required, + ToolChoice { tool: Tool }, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum ReasoningEffort { + Low, + Medium, + High, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(untagged)] +pub enum ReasoningMode { + Effort { effort: ReasoningEffort }, + MaxTokens { max_tokens: i64 }, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Reasoning { + #[serde(flatten)] + pub mode: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub exclude: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub enabled: Option, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[allow(non_camel_case_types)] +pub enum MessageRole { + user, + system, + assistant, + function, + tool, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Content { + Text(String), + ImageUrl(Vec), +} + +impl serde::Serialize for Content { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match *self { + Content::Text(ref text) => { + if text.is_empty() { + serializer.serialize_none() + } else { + serializer.serialize_str(text) + } + } + Content::ImageUrl(ref image_url) => image_url.serialize(serializer), + } + } +} + +impl<'de> Deserialize<'de> for Content { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct ContentVisitor; + + impl<'de> Visitor<'de> for ContentVisitor { + type Value = Content; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a valid content type") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + Ok(Content::Text(value.to_string())) + } + + fn visit_seq(self, seq: A) -> Result + where + A: SeqAccess<'de>, + { + let image_urls: Vec = + Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?; + Ok(Content::ImageUrl(image_urls)) + } + + fn visit_map(self, map: M) -> Result + where + M: MapAccess<'de>, + { + let image_urls: Vec = + Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?; + Ok(Content::ImageUrl(image_urls)) + } + + fn visit_none(self) -> Result + where + E: de::Error, + { + Ok(Content::Text(String::new())) + } + + fn visit_unit(self) -> Result + where + E: de::Error, + { + Ok(Content::Text(String::new())) + } + } + + deserializer.deserialize_any(ContentVisitor) + } +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[allow(non_camel_case_types)] +pub enum ContentType { + text, + image_url, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[allow(non_camel_case_types)] +pub struct ImageUrlType { + pub url: String, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[allow(non_camel_case_types)] +pub struct ImageUrl { + pub r#type: ContentType, + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub image_url: Option, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct ChatCompletionMessage { + pub role: MessageRole, + pub content: Content, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct ChatCompletionMessageForResponse { + pub role: MessageRole, + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct ChatCompletionChoice { + pub index: i64, + pub message: ChatCompletionMessageForResponse, + pub finish_reason: Option, + pub finish_details: Option, +} + +#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)] +#[allow(non_camel_case_types)] +pub enum FinishReason { + stop, + length, + content_filter, + tool_calls, + null, +} + +#[derive(Debug, Deserialize, Serialize)] +#[allow(non_camel_case_types)] +pub struct FinishDetails { + pub r#type: FinishReason, + pub stop: String, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct ToolCall { + pub id: String, + pub r#type: String, + pub function: ToolCallFunction, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct ToolCallFunction { + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, +} + +pub fn serialize_tool_choice( + value: &Option, + serializer: S, +) -> Result +where + S: Serializer, +{ + match value { + Some(ToolChoiceType::None) => serializer.serialize_str("none"), + Some(ToolChoiceType::Auto) => serializer.serialize_str("auto"), + Some(ToolChoiceType::Required) => serializer.serialize_str("required"), + Some(ToolChoiceType::ToolChoice { tool }) => { + let mut map = serializer.serialize_map(Some(2))?; + map.serialize_entry("type", &tool.r#type)?; + map.serialize_entry("function", &tool.function)?; + map.end() + } + None => serializer.serialize_none(), + } +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +pub struct Tool { + pub r#type: ToolType, + pub function: types::Function, +} + +#[derive(Debug, Deserialize, Serialize, Copy, Clone, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum ToolType { + Function, +}